// tensorlogic_cli/visualization.rs
//! Graph visualization: DOT export and ASCII rendering.
//!
//! Converts [`EinsumGraph`] into visual representations for debugging and documentation.
//! Builds on top of `tensorlogic_ir::dot_export` while adding CLI-specific features
//! such as configurable ASCII rendering, graph summary statistics, and file I/O.
//!
//! # Example
//!
//! ```rust,no_run
//! use tensorlogic_cli::visualization::{
//!     AsciiRenderer, DotExporter, GraphSummary, VisualizationConfig,
//! };
//! use tensorlogic_ir::{EinsumGraph, EinsumNode};
//!
//! let mut graph = EinsumGraph::new();
//! let t0 = graph.add_tensor("x".to_string());
//! let t1 = graph.add_tensor("y".to_string());
//! let node = EinsumNode::elem_unary("relu", t0, t1);
//! graph.add_node(node).expect("should add node");
//!
//! let config = VisualizationConfig::default();
//! let dot = DotExporter::export(&graph, &config);
//! let ascii = AsciiRenderer::render(&graph, &config);
//! let summary = GraphSummary::compute(&graph);
//! ```

use std::collections::{HashMap, HashSet};
use std::fmt::Write;

use tensorlogic_ir::{DotExportOptions, EinsumGraph, OpType};

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for graph visualization.
///
/// Shared by [`DotExporter`], [`AsciiRenderer`], and [`write_dot_file`].
/// Build one with [`VisualizationConfig::default`], [`VisualizationConfig::minimal`],
/// or the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Show operation details (spec strings, per-node input/output tensor lists).
    pub show_details: bool,
    /// Show tensor shapes in node labels (forwarded to DOT export options).
    pub show_shapes: bool,
    /// Maximum depth for ASCII rendering (0 = unlimited).
    /// NOTE(review): `AsciiRenderer::render` applies this as a cap on the
    /// *number of nodes printed*, not on dataflow depth — confirm intent.
    pub max_depth: usize,
    /// Use colour in DOT output; when `false`, `fillcolor`/`style=filled`
    /// attributes are stripped from the exported DOT.
    pub use_color: bool,
    /// Indent string for ASCII rendering.
    pub indent: String,
    /// Show tensor indices alongside names (forwarded to DOT export options).
    pub show_tensor_ids: bool,
    /// Show node indices alongside operation labels (forwarded to DOT export options).
    pub show_node_ids: bool,
    /// Use horizontal (LR) layout in DOT.
    pub horizontal_layout: bool,
    /// Cluster operations by type in DOT.
    pub cluster_by_operation: bool,
}

59impl Default for VisualizationConfig {
60    fn default() -> Self {
61        VisualizationConfig {
62            show_details: true,
63            show_shapes: true,
64            max_depth: 0,
65            use_color: true,
66            indent: "  ".to_string(),
67            show_tensor_ids: false,
68            show_node_ids: true,
69            horizontal_layout: false,
70            cluster_by_operation: false,
71        }
72    }
73}
74
75impl VisualizationConfig {
76    /// Create a new default configuration.
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Builder: toggle detail display.
82    pub fn with_details(mut self, v: bool) -> Self {
83        self.show_details = v;
84        self
85    }
86
87    /// Builder: toggle shape display.
88    pub fn with_shapes(mut self, v: bool) -> Self {
89        self.show_shapes = v;
90        self
91    }
92
93    /// Builder: set maximum ASCII rendering depth.
94    pub fn with_max_depth(mut self, d: usize) -> Self {
95        self.max_depth = d;
96        self
97    }
98
99    /// Builder: toggle colour in DOT output.
100    pub fn with_color(mut self, v: bool) -> Self {
101        self.use_color = v;
102        self
103    }
104
105    /// Builder: toggle tensor id display.
106    pub fn with_tensor_ids(mut self, v: bool) -> Self {
107        self.show_tensor_ids = v;
108        self
109    }
110
111    /// Builder: toggle node id display.
112    pub fn with_node_ids(mut self, v: bool) -> Self {
113        self.show_node_ids = v;
114        self
115    }
116
117    /// Builder: toggle horizontal layout.
118    pub fn with_horizontal_layout(mut self, v: bool) -> Self {
119        self.horizontal_layout = v;
120        self
121    }
122
123    /// Builder: toggle operation clustering.
124    pub fn with_clustering(mut self, v: bool) -> Self {
125        self.cluster_by_operation = v;
126        self
127    }
128
129    /// A minimal preset that hides most detail.
130    pub fn minimal() -> Self {
131        VisualizationConfig {
132            show_details: false,
133            show_shapes: false,
134            show_tensor_ids: false,
135            show_node_ids: false,
136            ..Self::default()
137        }
138    }
139
140    /// Convert to the upstream `DotExportOptions`.
141    fn to_dot_options(&self) -> DotExportOptions {
142        DotExportOptions {
143            show_tensor_ids: self.show_tensor_ids,
144            show_node_ids: self.show_node_ids,
145            show_metadata: self.show_details,
146            show_shapes: self.show_shapes,
147            cluster_by_operation: self.cluster_by_operation,
148            horizontal_layout: self.horizontal_layout,
149            highlight_tensors: Vec::new(),
150            highlight_nodes: Vec::new(),
151        }
152    }
153}
154
// ---------------------------------------------------------------------------
// DOT exporter
// ---------------------------------------------------------------------------

/// Export an [`EinsumGraph`] to Graphviz DOT format.
///
/// This wraps the lower-level `tensorlogic_ir::export_to_dot_with_options` with
/// the CLI-level [`VisualizationConfig`] and optionally strips colour attributes
/// when `use_color` is `false`.
pub struct DotExporter;

166impl DotExporter {
167    /// Export a graph to a DOT format string.
168    pub fn export(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
169        let options = config.to_dot_options();
170        let dot = tensorlogic_ir::export_to_dot_with_options(graph, &options);
171
172        if config.use_color {
173            dot
174        } else {
175            Self::strip_fill_colors(&dot)
176        }
177    }
178
179    /// Remove `fillcolor=...` and `style=filled` from a DOT string so that
180    /// the output is monochrome.
181    fn strip_fill_colors(dot: &str) -> String {
182        let mut result = String::with_capacity(dot.len());
183        for line in dot.lines() {
184            let cleaned = line
185                .replace(", style=filled", "")
186                .replace("style=filled, ", "")
187                .replace("style=filled", "");
188
189            // Remove fillcolor=<word>
190            let cleaned = strip_attr(&cleaned, "fillcolor=");
191            // Remove trailing ", ]" artifacts that may remain
192            let cleaned = cleaned.replace(", ];", "];").replace(",];", "];");
193            let _ = writeln!(result, "{}", cleaned);
194        }
195        result
196    }
197}
198
/// Remove a DOT attribute of the form `key=value` (unquoted single word).
///
/// Only the first occurrence of `prefix` is removed. A separating `", "` or
/// `","` left behind on either side of the removed attribute is cleaned up.
/// Quoted values containing spaces are not supported (the value is assumed to
/// end at the first comma, semicolon, `]`, or space).
fn strip_attr(line: &str, prefix: &str) -> String {
    let Some(start) = line.find(prefix) else {
        return line.to_string();
    };
    let (before, tail) = line.split_at(start);
    let value_and_rest = &tail[prefix.len()..];
    // The value ends at a comma, semicolon, ']', or space.
    let value_len = value_and_rest
        .find([',', ';', ']', ' '])
        .unwrap_or(value_and_rest.len());
    // Trim a leading ", " (or bare ',') that separated the removed attribute.
    let remainder = &value_and_rest[value_len..];
    let remainder = remainder.strip_prefix(", ").unwrap_or(remainder);
    let remainder = remainder.strip_prefix(',').unwrap_or(remainder);
    format!("{}{}", before.trim_end_matches(", "), remainder)
}

218/// Write DOT output to a file.
219pub fn write_dot_file(
220    path: &std::path::Path,
221    graph: &EinsumGraph,
222    config: &VisualizationConfig,
223) -> std::io::Result<()> {
224    let dot = DotExporter::export(graph, config);
225    std::fs::write(path, dot)
226}
227
// ---------------------------------------------------------------------------
// ASCII renderer
// ---------------------------------------------------------------------------

/// Render an [`EinsumGraph`] as ASCII art for terminal display.
pub struct AsciiRenderer;

235impl AsciiRenderer {
236    /// Render a graph to an ASCII string.
237    pub fn render(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
238        let mut out = String::new();
239
240        let _ = writeln!(out, "=== EinsumGraph ===");
241        let _ = writeln!(out, "Nodes: {}", graph.nodes.len());
242        let _ = writeln!(
243            out,
244            "Tensors: {} ({} inputs, {} outputs)",
245            graph.tensors.len(),
246            graph.inputs.len(),
247            graph.outputs.len()
248        );
249
250        // List output tensor names
251        if !graph.outputs.is_empty() {
252            let names: Vec<&str> = graph
253                .outputs
254                .iter()
255                .filter_map(|&idx| graph.tensors.get(idx).map(|s| s.as_str()))
256                .collect();
257            let _ = writeln!(out, "Outputs: [{}]", names.join(", "));
258        }
259
260        let _ = writeln!(out);
261
262        // Render each node
263        let depth_limit = if config.max_depth == 0 {
264            usize::MAX
265        } else {
266            config.max_depth
267        };
268
269        for (i, node) in graph.nodes.iter().enumerate() {
270            if i >= depth_limit {
271                let _ = writeln!(
272                    out,
273                    "{}... ({} more nodes)",
274                    config.indent,
275                    graph.nodes.len() - i
276                );
277                break;
278            }
279            Self::render_node(&mut out, graph, node, i, config);
280        }
281
282        let _ = writeln!(out, "===================");
283        out
284    }
285
286    fn render_node(
287        out: &mut String,
288        graph: &EinsumGraph,
289        node: &tensorlogic_ir::EinsumNode,
290        idx: usize,
291        config: &VisualizationConfig,
292    ) {
293        let indent = &config.indent;
294        let _ = write!(out, "{}[{}] ", indent, idx);
295
296        // Operation description
297        let _ = writeln!(out, "{}", node.operation_description());
298
299        if config.show_details {
300            // Inputs
301            let input_names: Vec<String> = node
302                .inputs
303                .iter()
304                .map(|&i| {
305                    graph
306                        .tensors
307                        .get(i)
308                        .cloned()
309                        .unwrap_or_else(|| format!("?{}", i))
310                })
311                .collect();
312            let _ = writeln!(
313                out,
314                "{}{} inputs: [{}]",
315                indent,
316                indent,
317                input_names.join(", ")
318            );
319
320            // Outputs
321            let output_names: Vec<String> = node
322                .outputs
323                .iter()
324                .map(|&i| {
325                    graph
326                        .tensors
327                        .get(i)
328                        .cloned()
329                        .unwrap_or_else(|| format!("?{}", i))
330                })
331                .collect();
332            let _ = writeln!(
333                out,
334                "{}{} outputs: [{}]",
335                indent,
336                indent,
337                output_names.join(", ")
338            );
339        }
340    }
341}
342
// ---------------------------------------------------------------------------
// Graph summary / statistics
// ---------------------------------------------------------------------------

/// Lightweight statistics computed from an [`EinsumGraph`].
///
/// Produced by [`GraphSummary::compute`]; pretty-printed via
/// [`GraphSummary::display`].
#[derive(Debug, Clone)]
pub struct GraphSummary {
    /// Number of computation nodes.
    pub node_count: usize,
    /// Number of named tensors.
    pub tensor_count: usize,
    /// Number of graph outputs.
    pub output_count: usize,
    /// Number of graph inputs.
    pub input_count: usize,
    /// Maximum fan-in (number of inputs) across all nodes.
    pub max_fan_in: usize,
    /// Maximum fan-out (number of outputs) across all nodes.
    pub max_fan_out: usize,
    /// Longest path through the dataflow graph, counted in nodes
    /// (a single-node graph has depth 1; an empty graph has depth 0).
    pub depth: usize,
    /// Operation type distribution, keyed by `OpType` variant name.
    pub op_counts: HashMap<String, usize>,
}

impl GraphSummary {
    /// Compute summary statistics for the given graph.
    ///
    /// Runs in a single pass over the nodes for counts and fan-in/out, plus a
    /// depth-first traversal (memoised) for the longest-path depth.
    pub fn compute(graph: &EinsumGraph) -> Self {
        let node_count = graph.nodes.len();
        let tensor_count = graph.tensors.len();
        let output_count = graph.outputs.len();
        let input_count = graph.inputs.len();

        // Fan-in/out are 0 for an empty graph (unwrap_or below).
        let max_fan_in = graph
            .nodes
            .iter()
            .map(|n| n.inputs.len())
            .max()
            .unwrap_or(0);
        let max_fan_out = graph
            .nodes
            .iter()
            .map(|n| n.outputs.len())
            .max()
            .unwrap_or(0);

        // Exhaustive match on OpType: adding a variant upstream forces an
        // update here at compile time.
        let mut op_counts: HashMap<String, usize> = HashMap::new();
        for node in &graph.nodes {
            let key = match &node.op {
                OpType::Einsum { .. } => "Einsum",
                OpType::ElemUnary { .. } => "ElemUnary",
                OpType::ElemBinary { .. } => "ElemBinary",
                OpType::Reduce { .. } => "Reduce",
            };
            *op_counts.entry(key.to_string()).or_insert(0) += 1;
        }

        let depth = Self::compute_depth(graph);

        GraphSummary {
            node_count,
            tensor_count,
            output_count,
            input_count,
            max_fan_in,
            max_fan_out,
            depth,
            op_counts,
        }
    }

    /// Compute the longest path through the dataflow graph using topological
    /// traversal. Each node is assigned a depth equal to 1 + max depth of any
    /// node producing one of its input tensors.
    ///
    /// Returns 0 for an empty graph; any non-empty graph has depth >= 1
    /// (see the final `+ 1` below).
    fn compute_depth(graph: &EinsumGraph) -> usize {
        if graph.nodes.is_empty() {
            return 0;
        }

        // Build a map: tensor_idx -> producing node index.
        // If multiple nodes write the same tensor, the last one wins.
        let mut producer: HashMap<usize, usize> = HashMap::new();
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            for &out_t in &node.outputs {
                producer.insert(out_t, node_idx);
            }
        }

        // Memo for node depths; shared across all DFS start points so each
        // node's depth is computed at most once.
        let num_nodes = graph.nodes.len();
        let mut memo: Vec<Option<usize>> = vec![None; num_nodes];

        // Recursive DFS: depth of a node = 1 + max depth over the producers
        // of its input tensors (0 if it has no produced inputs).
        fn depth_of(
            node_idx: usize,
            graph: &EinsumGraph,
            producer: &HashMap<usize, usize>,
            memo: &mut [Option<usize>],
            visited: &mut HashSet<usize>,
        ) -> usize {
            if let Some(d) = memo[node_idx] {
                return d;
            }
            // Cycle guard: if we re-enter a node on the current DFS path,
            // treat its depth as 0 rather than recursing forever. `visited`
            // entries are never removed, but the memo check above ensures
            // diamond-shaped (non-cyclic) re-visits still return the cached
            // depth rather than hitting this guard.
            if !visited.insert(node_idx) {
                return 0;
            }
            let node = &graph.nodes[node_idx];
            let mut max_pred = 0usize;
            for &inp_t in &node.inputs {
                // Inputs with no producer (graph inputs) contribute depth 0.
                if let Some(&pred_node) = producer.get(&inp_t) {
                    let d = depth_of(pred_node, graph, producer, memo, visited);
                    if d + 1 > max_pred {
                        max_pred = d + 1;
                    }
                }
            }
            memo[node_idx] = Some(max_pred);
            max_pred
        }

        // Start a DFS from every node; `visited` is fresh per start so the
        // cycle guard stays local, while `memo` persists across starts.
        let mut max_depth = 0usize;
        for i in 0..num_nodes {
            let mut visited = HashSet::new();
            let d = depth_of(i, graph, &producer, &mut memo, &mut visited);
            if d > max_depth {
                max_depth = d;
            }
        }

        // depth counts edges; add 1 so a single-node graph has depth 1.
        max_depth + 1
    }

    /// Pretty-print the summary as a multi-line, human-readable string.
    pub fn display(&self) -> String {
        let mut out = String::new();
        let _ = writeln!(out, "Graph Summary:");
        let _ = writeln!(out, "  Nodes:   {}", self.node_count);
        let _ = writeln!(out, "  Tensors: {}", self.tensor_count);
        let _ = writeln!(out, "  Inputs:  {}", self.input_count);
        let _ = writeln!(out, "  Outputs: {}", self.output_count);
        let _ = writeln!(out, "  Depth:   {}", self.depth);
        let _ = writeln!(out, "  Max fan-in:  {}", self.max_fan_in);
        let _ = writeln!(out, "  Max fan-out: {}", self.max_fan_out);
        if !self.op_counts.is_empty() {
            let _ = writeln!(out, "  Operations:");
            // Sort by operation name so output is deterministic despite
            // HashMap iteration order.
            let mut sorted: Vec<_> = self.op_counts.iter().collect();
            sorted.sort_by_key(|(k, _)| (*k).clone());
            for (op, count) in sorted {
                let _ = writeln!(out, "    {}: {}", op, count);
            }
        }
        out
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode};

    /// Helper: build an empty graph.
    fn empty_graph() -> EinsumGraph {
        EinsumGraph::new()
    }

    /// Helper: build a small graph with two nodes:
    /// `c = add(a, b)` followed by `d = relu(c)`, with inputs `[a, b]`
    /// and output `[d]`.
    fn small_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("a".to_string());
        let b = g.add_tensor("b".to_string());
        let c = g.add_tensor("c".to_string());
        let d = g.add_tensor("d".to_string());
        g.inputs = vec![a, b];
        g.outputs = vec![d];
        g.add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("node add");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("node relu");
        g
    }

    // -----------------------------------------------------------------------
    // DOT export tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_dot_export_empty_graph() {
        let g = empty_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("digraph"));
    }

    #[test]
    fn test_dot_export_contains_nodes() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("op_0"));
        assert!(dot.contains("op_1"));
    }

    #[test]
    fn test_dot_export_contains_edges() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        // a->add, b->add
        assert!(dot.contains("tensor_0 -> op_0"));
        assert!(dot.contains("tensor_1 -> op_0"));
        // add->c
        assert!(dot.contains("op_0 -> tensor_2"));
        // c->relu
        assert!(dot.contains("tensor_2 -> op_1"));
        // relu->d
        assert!(dot.contains("op_1 -> tensor_3"));
    }

    #[test]
    fn test_dot_export_no_color() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_color(false);
        let dot = DotExporter::export(&g, &config);
        // fillcolor attributes should be stripped
        assert!(!dot.contains("fillcolor"));
    }

    #[test]
    fn test_dot_export_minimal_config() {
        let g = small_graph();
        let full = DotExporter::export(&g, &VisualizationConfig::default());
        let minimal = DotExporter::export(&g, &VisualizationConfig::minimal());
        // Minimal should still be valid DOT but may be shorter (no node ids etc.)
        assert!(minimal.contains("digraph"));
        assert!(minimal.len() <= full.len());
    }

    #[test]
    fn test_write_dot_file() {
        // NOTE(review): fixed filename in the shared temp dir — concurrent
        // runs of this test binary could collide; consider a unique name.
        let g = small_graph();
        let dir = std::env::temp_dir();
        let path = dir.join("tensorlogic_test_viz.dot");
        write_dot_file(&path, &g, &VisualizationConfig::default()).expect("should write file");
        let contents = std::fs::read_to_string(&path).expect("should read file");
        assert!(contents.contains("digraph"));
        let _ = std::fs::remove_file(&path);
    }

    // -----------------------------------------------------------------------
    // ASCII renderer tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_ascii_render_header() {
        let g = empty_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.starts_with("=== EinsumGraph ==="));
    }

    #[test]
    fn test_ascii_render_node_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.contains("Nodes: 2"));
    }

    #[test]
    fn test_ascii_render_output_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        // Output tensor is "d"
        assert!(ascii.contains("Outputs: [d]"));
    }

    #[test]
    fn test_ascii_render_details() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_details(true);
        let ascii = AsciiRenderer::render(&g, &config);
        assert!(ascii.contains("inputs:"));
        assert!(ascii.contains("outputs:"));
    }

    #[test]
    fn test_ascii_render_no_details() {
        let g = small_graph();
        let with_details =
            AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(true));
        let without = AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(false));
        assert!(without.len() < with_details.len());
        assert!(!without.contains("inputs:"));
    }

    // -----------------------------------------------------------------------
    // Config tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let c = VisualizationConfig::default();
        assert!(c.show_details);
        assert!(c.show_shapes);
        assert_eq!(c.max_depth, 0);
        assert!(c.use_color);
        assert_eq!(c.indent, "  ");
    }

    #[test]
    fn test_config_builder() {
        let c = VisualizationConfig::new()
            .with_details(false)
            .with_shapes(false)
            .with_max_depth(5)
            .with_color(false);
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert_eq!(c.max_depth, 5);
        assert!(!c.use_color);
    }

    #[test]
    fn test_config_minimal() {
        let c = VisualizationConfig::minimal();
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert!(!c.show_tensor_ids);
        assert!(!c.show_node_ids);
    }

    // -----------------------------------------------------------------------
    // Graph summary tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_graph_summary_empty() {
        let g = empty_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.tensor_count, 0);
        assert_eq!(s.output_count, 0);
        assert_eq!(s.input_count, 0);
        assert_eq!(s.max_fan_in, 0);
        assert_eq!(s.max_fan_out, 0);
        assert_eq!(s.depth, 0);
    }

    #[test]
    fn test_graph_summary_basic() {
        let g = small_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 2);
        assert_eq!(s.tensor_count, 4);
        assert_eq!(s.output_count, 1);
        assert_eq!(s.input_count, 2);
        assert_eq!(s.max_fan_in, 2); // binary add has 2 inputs
        assert_eq!(s.max_fan_out, 1);
        assert_eq!(s.depth, 2); // add -> relu chain
        assert_eq!(s.op_counts.get("ElemBinary"), Some(&1));
        assert_eq!(s.op_counts.get("ElemUnary"), Some(&1));
    }

    // -----------------------------------------------------------------------
    // Determinism tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_dot_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = DotExporter::export(&g, &config);
        let b = DotExporter::export(&g, &config);
        assert_eq!(a, b);
    }

    #[test]
    fn test_ascii_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = AsciiRenderer::render(&g, &config);
        let b = AsciiRenderer::render(&g, &config);
        assert_eq!(a, b);
    }
}