1use std::collections::{HashMap, HashSet};
28use std::fmt::Write;
29
30use tensorlogic_ir::{DotExportOptions, EinsumGraph, OpType};
31
/// Options controlling how an [`EinsumGraph`] is rendered by [`DotExporter`] and
/// [`AsciiRenderer`].
///
/// Start from [`Default::default`] (or `VisualizationConfig::new`) and adjust with the
/// `with_*` builder methods, or start from `VisualizationConfig::minimal`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VisualizationConfig {
    /// Show per-node details: input/output tensor names in ASCII output, metadata in DOT.
    pub show_details: bool,
    /// Show tensor shapes (forwarded to the DOT export options).
    pub show_shapes: bool,
    /// Maximum number of nodes listed by the ASCII renderer; `0` means no limit.
    pub max_depth: usize,
    /// Keep fill colours in DOT output; when `false` they are stripped after export.
    pub use_color: bool,
    /// Indentation unit used by the ASCII renderer.
    pub indent: String,
    /// Show tensor ids (forwarded to the DOT export options).
    pub show_tensor_ids: bool,
    /// Show node ids (forwarded to the DOT export options).
    pub show_node_ids: bool,
    /// Request a horizontal DOT layout (forwarded to the DOT export options).
    pub horizontal_layout: bool,
    /// Cluster DOT nodes by operation (forwarded to the DOT export options).
    pub cluster_by_operation: bool,
}

impl Default for VisualizationConfig {
    /// Details, shapes, colour and node ids on; tensor ids, horizontal layout and
    /// clustering off; no node limit; a single-space indent.
    fn default() -> Self {
        VisualizationConfig {
            show_details: true,
            show_shapes: true,
            max_depth: 0,
            use_color: true,
            indent: " ".to_string(),
            show_tensor_ids: false,
            show_node_ids: true,
            horizontal_layout: false,
            cluster_by_operation: false,
        }
    }
}
74
75impl VisualizationConfig {
76 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn with_details(mut self, v: bool) -> Self {
83 self.show_details = v;
84 self
85 }
86
87 pub fn with_shapes(mut self, v: bool) -> Self {
89 self.show_shapes = v;
90 self
91 }
92
93 pub fn with_max_depth(mut self, d: usize) -> Self {
95 self.max_depth = d;
96 self
97 }
98
99 pub fn with_color(mut self, v: bool) -> Self {
101 self.use_color = v;
102 self
103 }
104
105 pub fn with_tensor_ids(mut self, v: bool) -> Self {
107 self.show_tensor_ids = v;
108 self
109 }
110
111 pub fn with_node_ids(mut self, v: bool) -> Self {
113 self.show_node_ids = v;
114 self
115 }
116
117 pub fn with_horizontal_layout(mut self, v: bool) -> Self {
119 self.horizontal_layout = v;
120 self
121 }
122
123 pub fn with_clustering(mut self, v: bool) -> Self {
125 self.cluster_by_operation = v;
126 self
127 }
128
129 pub fn minimal() -> Self {
131 VisualizationConfig {
132 show_details: false,
133 show_shapes: false,
134 show_tensor_ids: false,
135 show_node_ids: false,
136 ..Self::default()
137 }
138 }
139
140 fn to_dot_options(&self) -> DotExportOptions {
142 DotExportOptions {
143 show_tensor_ids: self.show_tensor_ids,
144 show_node_ids: self.show_node_ids,
145 show_metadata: self.show_details,
146 show_shapes: self.show_shapes,
147 cluster_by_operation: self.cluster_by_operation,
148 horizontal_layout: self.horizontal_layout,
149 highlight_tensors: Vec::new(),
150 highlight_nodes: Vec::new(),
151 }
152 }
153}
154
155pub struct DotExporter;
165
166impl DotExporter {
167 pub fn export(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
169 let options = config.to_dot_options();
170 let dot = tensorlogic_ir::export_to_dot_with_options(graph, &options);
171
172 if config.use_color {
173 dot
174 } else {
175 Self::strip_fill_colors(&dot)
176 }
177 }
178
179 fn strip_fill_colors(dot: &str) -> String {
182 let mut result = String::with_capacity(dot.len());
183 for line in dot.lines() {
184 let cleaned = line
185 .replace(", style=filled", "")
186 .replace("style=filled, ", "")
187 .replace("style=filled", "");
188
189 let cleaned = strip_attr(&cleaned, "fillcolor=");
191 let cleaned = cleaned.replace(", ];", "];").replace(",];", "];");
193 let _ = writeln!(result, "{}", cleaned);
194 }
195 result
196 }
197}
198
/// Removes every `prefix<value>` attribute (e.g. `fillcolor=red`) from a single DOT line.
///
/// A value ends at the next `,`, `;`, `]` or space. If it starts with `"`, it ends at the
/// matching unescaped closing quote instead, so quoted values containing spaces or commas
/// are removed whole.
///
/// Exactly one neighbouring list separator is dropped together with the attribute, so the
/// remaining attributes stay separated: `[a=1, fillcolor=red, b=2]` becomes `[a=1, b=2]`.
///
/// The search is purely textual, so `prefix` text inside another attribute's quoted value
/// is removed as well. An empty `prefix` returns the line unchanged.
fn strip_attr(line: &str, prefix: &str) -> String {
    if prefix.is_empty() {
        return line.to_string();
    }
    let mut current = line.to_string();
    // Each pass removes at least `prefix.len()` bytes, so the loop terminates.
    while let Some(start) = current.find(prefix) {
        let before = &current[..start];
        let after_key = &current[start + prefix.len()..];
        let rest = &after_key[dot_attr_value_len(after_key)..];
        let joined = match dot_strip_leading_separator(rest) {
            // A separator follows the attribute: drop it and keep the one before (if any).
            Some(rest) => format!("{}{}", before, rest),
            // The attribute ended the list: drop the separator that preceded it instead.
            None => format!(
                "{}{}",
                dot_trim_trailing_comma(before).unwrap_or(before),
                rest
            ),
        };
        current = joined;
    }
    current
}

/// Byte length of the attribute value at the start of `s`.
///
/// Handles a quoted value up to and including its closing quote (backslash escapes
/// respected); otherwise the value runs to the first `,`, `;`, `]` or space.
fn dot_attr_value_len(s: &str) -> usize {
    if let Some(quoted) = s.strip_prefix('"') {
        let mut escaped = false;
        for (i, c) in quoted.char_indices() {
            match c {
                '\\' if !escaped => escaped = true,
                // +2 accounts for the opening and closing quote bytes.
                '"' if !escaped => return i + 2,
                _ => escaped = false,
            }
        }
        // Unterminated quote: treat the rest of the line as the value.
        s.len()
    } else {
        s.find([',', ';', ']', ' ']).unwrap_or(s.len())
    }
}

/// Strips a leading `,` or `;` separator (plus following whitespace), if present.
fn dot_strip_leading_separator(s: &str) -> Option<&str> {
    s.strip_prefix(',')
        .or_else(|| s.strip_prefix(';'))
        .map(str::trim_start)
}

/// Strips a trailing `,` separator (plus surrounding whitespace), if present.
fn dot_trim_trailing_comma(s: &str) -> Option<&str> {
    s.trim_end().strip_suffix(',').map(str::trim_end)
}
217
218pub fn write_dot_file(
220 path: &std::path::Path,
221 graph: &EinsumGraph,
222 config: &VisualizationConfig,
223) -> std::io::Result<()> {
224 let dot = DotExporter::export(graph, config);
225 std::fs::write(path, dot)
226}
227
228pub struct AsciiRenderer;
234
235impl AsciiRenderer {
236 pub fn render(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
238 let mut out = String::new();
239
240 let _ = writeln!(out, "=== EinsumGraph ===");
241 let _ = writeln!(out, "Nodes: {}", graph.nodes.len());
242 let _ = writeln!(
243 out,
244 "Tensors: {} ({} inputs, {} outputs)",
245 graph.tensors.len(),
246 graph.inputs.len(),
247 graph.outputs.len()
248 );
249
250 if !graph.outputs.is_empty() {
252 let names: Vec<&str> = graph
253 .outputs
254 .iter()
255 .filter_map(|&idx| graph.tensors.get(idx).map(|s| s.as_str()))
256 .collect();
257 let _ = writeln!(out, "Outputs: [{}]", names.join(", "));
258 }
259
260 let _ = writeln!(out);
261
262 let depth_limit = if config.max_depth == 0 {
264 usize::MAX
265 } else {
266 config.max_depth
267 };
268
269 for (i, node) in graph.nodes.iter().enumerate() {
270 if i >= depth_limit {
271 let _ = writeln!(
272 out,
273 "{}... ({} more nodes)",
274 config.indent,
275 graph.nodes.len() - i
276 );
277 break;
278 }
279 Self::render_node(&mut out, graph, node, i, config);
280 }
281
282 let _ = writeln!(out, "===================");
283 out
284 }
285
286 fn render_node(
287 out: &mut String,
288 graph: &EinsumGraph,
289 node: &tensorlogic_ir::EinsumNode,
290 idx: usize,
291 config: &VisualizationConfig,
292 ) {
293 let indent = &config.indent;
294 let _ = write!(out, "{}[{}] ", indent, idx);
295
296 let _ = writeln!(out, "{}", node.operation_description());
298
299 if config.show_details {
300 let input_names: Vec<String> = node
302 .inputs
303 .iter()
304 .map(|&i| {
305 graph
306 .tensors
307 .get(i)
308 .cloned()
309 .unwrap_or_else(|| format!("?{}", i))
310 })
311 .collect();
312 let _ = writeln!(
313 out,
314 "{}{} inputs: [{}]",
315 indent,
316 indent,
317 input_names.join(", ")
318 );
319
320 let output_names: Vec<String> = node
322 .outputs
323 .iter()
324 .map(|&i| {
325 graph
326 .tensors
327 .get(i)
328 .cloned()
329 .unwrap_or_else(|| format!("?{}", i))
330 })
331 .collect();
332 let _ = writeln!(
333 out,
334 "{}{} outputs: [{}]",
335 indent,
336 indent,
337 output_names.join(", ")
338 );
339 }
340 }
341}
342
/// Structural statistics for an [`EinsumGraph`], produced by `GraphSummary::compute`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphSummary {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of tensors in the graph.
    pub tensor_count: usize,
    /// Number of graph outputs.
    pub output_count: usize,
    /// Number of graph inputs.
    pub input_count: usize,
    /// Largest number of inputs consumed by a single node (0 for an empty graph).
    pub max_fan_in: usize,
    /// Largest number of outputs produced by a single node (0 for an empty graph).
    pub max_fan_out: usize,
    /// Number of nodes on the longest producer-to-consumer chain (0 for an empty graph).
    pub depth: usize,
    /// Node count per operation kind (`"Einsum"`, `"ElemUnary"`, `"ElemBinary"`, `"Reduce"`).
    pub op_counts: HashMap<String, usize>,
}
367
368impl GraphSummary {
369 pub fn compute(graph: &EinsumGraph) -> Self {
371 let node_count = graph.nodes.len();
372 let tensor_count = graph.tensors.len();
373 let output_count = graph.outputs.len();
374 let input_count = graph.inputs.len();
375
376 let max_fan_in = graph
377 .nodes
378 .iter()
379 .map(|n| n.inputs.len())
380 .max()
381 .unwrap_or(0);
382 let max_fan_out = graph
383 .nodes
384 .iter()
385 .map(|n| n.outputs.len())
386 .max()
387 .unwrap_or(0);
388
389 let mut op_counts: HashMap<String, usize> = HashMap::new();
390 for node in &graph.nodes {
391 let key = match &node.op {
392 OpType::Einsum { .. } => "Einsum",
393 OpType::ElemUnary { .. } => "ElemUnary",
394 OpType::ElemBinary { .. } => "ElemBinary",
395 OpType::Reduce { .. } => "Reduce",
396 };
397 *op_counts.entry(key.to_string()).or_insert(0) += 1;
398 }
399
400 let depth = Self::compute_depth(graph);
401
402 GraphSummary {
403 node_count,
404 tensor_count,
405 output_count,
406 input_count,
407 max_fan_in,
408 max_fan_out,
409 depth,
410 op_counts,
411 }
412 }
413
414 fn compute_depth(graph: &EinsumGraph) -> usize {
418 if graph.nodes.is_empty() {
419 return 0;
420 }
421
422 let mut producer: HashMap<usize, usize> = HashMap::new();
424 for (node_idx, node) in graph.nodes.iter().enumerate() {
425 for &out_t in &node.outputs {
426 producer.insert(out_t, node_idx);
427 }
428 }
429
430 let num_nodes = graph.nodes.len();
432 let mut memo: Vec<Option<usize>> = vec![None; num_nodes];
433
434 fn depth_of(
435 node_idx: usize,
436 graph: &EinsumGraph,
437 producer: &HashMap<usize, usize>,
438 memo: &mut [Option<usize>],
439 visited: &mut HashSet<usize>,
440 ) -> usize {
441 if let Some(d) = memo[node_idx] {
442 return d;
443 }
444 if !visited.insert(node_idx) {
446 return 0;
447 }
448 let node = &graph.nodes[node_idx];
449 let mut max_pred = 0usize;
450 for &inp_t in &node.inputs {
451 if let Some(&pred_node) = producer.get(&inp_t) {
452 let d = depth_of(pred_node, graph, producer, memo, visited);
453 if d + 1 > max_pred {
454 max_pred = d + 1;
455 }
456 }
457 }
458 memo[node_idx] = Some(max_pred);
459 max_pred
460 }
461
462 let mut max_depth = 0usize;
463 for i in 0..num_nodes {
464 let mut visited = HashSet::new();
465 let d = depth_of(i, graph, &producer, &mut memo, &mut visited);
466 if d > max_depth {
467 max_depth = d;
468 }
469 }
470
471 max_depth + 1
473 }
474
475 pub fn display(&self) -> String {
477 let mut out = String::new();
478 let _ = writeln!(out, "Graph Summary:");
479 let _ = writeln!(out, " Nodes: {}", self.node_count);
480 let _ = writeln!(out, " Tensors: {}", self.tensor_count);
481 let _ = writeln!(out, " Inputs: {}", self.input_count);
482 let _ = writeln!(out, " Outputs: {}", self.output_count);
483 let _ = writeln!(out, " Depth: {}", self.depth);
484 let _ = writeln!(out, " Max fan-in: {}", self.max_fan_in);
485 let _ = writeln!(out, " Max fan-out: {}", self.max_fan_out);
486 if !self.op_counts.is_empty() {
487 let _ = writeln!(out, " Operations:");
488 let mut sorted: Vec<_> = self.op_counts.iter().collect();
489 sorted.sort_by_key(|(k, _)| (*k).clone());
490 for (op, count) in sorted {
491 let _ = writeln!(out, " {}: {}", op, count);
492 }
493 }
494 out
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode};

    /// A graph with no tensors and no nodes.
    fn empty_graph() -> EinsumGraph {
        EinsumGraph::new()
    }

    /// `d = relu(a + b)`: inputs `a`, `b`, intermediate `c`, output `d`.
    fn small_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("a".to_string());
        let b = g.add_tensor("b".to_string());
        let c = g.add_tensor("c".to_string());
        let d = g.add_tensor("d".to_string());
        g.inputs = vec![a, b];
        g.outputs = vec![d];
        g.add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("node add");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("node relu");
        g
    }

    // ---- DOT export ----

    #[test]
    fn test_dot_export_empty_graph() {
        let g = empty_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("digraph"));
    }

    #[test]
    fn test_dot_export_contains_nodes() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("op_0"));
        assert!(dot.contains("op_1"));
    }

    #[test]
    fn test_dot_export_contains_edges() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("tensor_0 -> op_0"));
        assert!(dot.contains("tensor_1 -> op_0"));
        assert!(dot.contains("op_0 -> tensor_2"));
        assert!(dot.contains("tensor_2 -> op_1"));
        assert!(dot.contains("op_1 -> tensor_3"));
    }

    #[test]
    fn test_dot_export_no_color() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_color(false);
        let dot = DotExporter::export(&g, &config);
        assert!(!dot.contains("fillcolor"));
    }

    #[test]
    fn test_dot_export_minimal_config() {
        let g = small_graph();
        let full = DotExporter::export(&g, &VisualizationConfig::default());
        let minimal = DotExporter::export(&g, &VisualizationConfig::minimal());
        assert!(minimal.contains("digraph"));
        assert!(minimal.len() <= full.len());
    }

    #[test]
    fn test_write_dot_file() {
        let g = small_graph();
        // Include the process id so concurrent test runs do not clobber each other's file.
        let path = std::env::temp_dir().join(format!(
            "tensorlogic_test_viz_{}.dot",
            std::process::id()
        ));
        write_dot_file(&path, &g, &VisualizationConfig::default()).expect("should write file");
        let contents = std::fs::read_to_string(&path);
        // Clean up before asserting so a failure does not leak the temp file.
        let _ = std::fs::remove_file(&path);
        let contents = contents.expect("should read file");
        assert!(contents.contains("digraph"));
    }

    // ---- ASCII rendering ----

    #[test]
    fn test_ascii_render_header() {
        let g = empty_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.starts_with("=== EinsumGraph ==="));
    }

    #[test]
    fn test_ascii_render_node_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.contains("Nodes: 2"));
    }

    #[test]
    fn test_ascii_render_output_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.contains("Outputs: [d]"));
    }

    #[test]
    fn test_ascii_render_details() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_details(true);
        let ascii = AsciiRenderer::render(&g, &config);
        assert!(ascii.contains("inputs:"));
        assert!(ascii.contains("outputs:"));
    }

    #[test]
    fn test_ascii_render_no_details() {
        let g = small_graph();
        let with_details =
            AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(true));
        let without = AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(false));
        assert!(without.len() < with_details.len());
        assert!(!without.contains("inputs:"));
    }

    // ---- Configuration ----

    #[test]
    fn test_config_default() {
        let c = VisualizationConfig::default();
        assert!(c.show_details);
        assert!(c.show_shapes);
        assert_eq!(c.max_depth, 0);
        assert!(c.use_color);
        assert_eq!(c.indent, " ");
    }

    #[test]
    fn test_config_builder() {
        let c = VisualizationConfig::new()
            .with_details(false)
            .with_shapes(false)
            .with_max_depth(5)
            .with_color(false);
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert_eq!(c.max_depth, 5);
        assert!(!c.use_color);
    }

    #[test]
    fn test_config_minimal() {
        let c = VisualizationConfig::minimal();
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert!(!c.show_tensor_ids);
        assert!(!c.show_node_ids);
    }

    // ---- Graph summary ----

    #[test]
    fn test_graph_summary_empty() {
        let g = empty_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.tensor_count, 0);
        assert_eq!(s.output_count, 0);
        assert_eq!(s.input_count, 0);
        assert_eq!(s.max_fan_in, 0);
        assert_eq!(s.max_fan_out, 0);
        assert_eq!(s.depth, 0);
    }

    #[test]
    fn test_graph_summary_basic() {
        let g = small_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 2);
        assert_eq!(s.tensor_count, 4);
        assert_eq!(s.output_count, 1);
        assert_eq!(s.input_count, 2);
        assert_eq!(s.max_fan_in, 2);
        assert_eq!(s.max_fan_out, 1);
        assert_eq!(s.depth, 2);
        assert_eq!(s.op_counts.get("ElemBinary"), Some(&1));
        assert_eq!(s.op_counts.get("ElemUnary"), Some(&1));
    }

    // ---- Determinism ----

    #[test]
    fn test_dot_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = DotExporter::export(&g, &config);
        let b = DotExporter::export(&g, &config);
        assert_eq!(a, b);
    }

    #[test]
    fn test_ascii_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = AsciiRenderer::render(&g, &config);
        let b = AsciiRenderer::render(&g, &config);
        assert_eq!(a, b);
    }
}