tensorlogic_ir/graph/dot_export.rs

//! DOT format export for graph visualization.
//!
//! This module provides utilities to export `EinsumGraph` to DOT format
//! for visualization with Graphviz and similar tools.
//!
//! # Example
//!
//! ```
//! use tensorlogic_ir::{EinsumGraph, EinsumNode};
//!
//! let mut graph = EinsumGraph::new();
//! let t0 = graph.add_tensor("input".to_string());
//! let t1 = graph.add_tensor("output".to_string());
//! let node = EinsumNode::elem_unary("relu", t0, t1);
//! graph.add_node(node).unwrap();
//!
//! let dot = tensorlogic_ir::export_to_dot(&graph);
//! println!("{}", dot);
//! ```

use crate::graph::{EinsumGraph, EinsumNode, OpType};
use std::collections::{HashMap, HashSet};
use std::fmt::Write as FmtWrite;

/// Export an `EinsumGraph` to DOT format.
///
/// The resulting DOT string can be rendered with Graphviz:
/// ```bash
/// echo "..." | dot -Tpng > graph.png
/// echo "..." | dot -Tsvg > graph.svg
/// ```
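///
/// Alternatively, write the DOT text to a file from Rust and render the file
/// afterwards (a minimal sketch using `std::fs::write`; the file names are
/// illustrative):
/// ```no_run
/// use tensorlogic_ir::{EinsumGraph, export_to_dot};
///
/// let graph = EinsumGraph::new();
/// let dot = export_to_dot(&graph);
/// std::fs::write("graph.dot", &dot).unwrap();
/// // Then render with: dot -Tsvg graph.dot -o graph.svg
/// ```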
///
/// # Layout Options
///
/// The generated DOT uses the following attributes:
/// - **Tensor nodes**: boxes, filled by role (inputs light blue, outputs light
///   green, tensors that are both purple, other tensors light yellow;
///   highlighted tensors red)
/// - **Operation nodes**: ellipses, filled by operation type (highlighted
///   operations orange)
/// - **Edges**: show data flow from input tensors through operations to
///   output tensors
///
/// # Example
///
/// ```
/// use tensorlogic_ir::{EinsumGraph, EinsumNode, export_to_dot};
///
/// let mut graph = EinsumGraph::new();
/// let input = graph.add_tensor("x".to_string());
/// let output = graph.add_tensor("y".to_string());
///
/// let node = EinsumNode::elem_unary("relu", input, output);
/// graph.add_node(node).unwrap();
///
/// let dot = export_to_dot(&graph);
/// assert!(dot.contains("digraph"));
/// assert!(dot.contains("relu"));
/// ```
pub fn export_to_dot(graph: &EinsumGraph) -> String {
    let mut output = String::new();
    export_to_dot_writer(graph, &mut output).expect("String write should not fail");
    output
}

/// Export an `EinsumGraph` to DOT format with custom options.
///
/// # Options
///
/// - `show_tensor_ids`: Show tensor indices in labels
/// - `show_node_ids`: Show node indices in labels
/// - `show_metadata`: Include metadata in node labels
/// - `cluster_by_operation`: Group operations by type
/// - `horizontal_layout`: Use left-to-right layout instead of top-to-bottom
/// - `show_shapes`: Include tensor shapes in labels (if available)
/// - `highlight_tensors` / `highlight_nodes`: Highlight specific tensors or operations
///
/// # Example
///
/// ```
/// use tensorlogic_ir::{EinsumGraph, EinsumNode, DotExportOptions, export_to_dot_with_options};
///
/// let mut graph = EinsumGraph::new();
/// let t0 = graph.add_tensor("input".to_string());
/// let t1 = graph.add_tensor("output".to_string());
/// let node = EinsumNode::elem_unary("sigmoid", t0, t1);
/// graph.add_node(node).unwrap();
///
/// let options = DotExportOptions {
///     show_tensor_ids: true,
///     show_node_ids: true,
///     horizontal_layout: true,
///     ..Default::default()
/// };
///
/// let dot = export_to_dot_with_options(&graph, &options);
/// assert!(dot.contains("rankdir=LR"));
/// ```
pub fn export_to_dot_with_options(graph: &EinsumGraph, options: &DotExportOptions) -> String {
    let mut output = String::new();
    export_to_dot_writer_with_options(graph, &mut output, options)
        .expect("String write should not fail");
    output
}

/// Options for DOT export customization.
#[derive(Debug, Clone, Default)]
pub struct DotExportOptions {
    /// Show tensor indices in labels (e.g., "input \[0\]" for tensor 0 named "input")
    pub show_tensor_ids: bool,
    /// Show node indices in labels (e.g., appending "\[op_0\]")
    pub show_node_ids: bool,
    /// Include metadata in node labels
    pub show_metadata: bool,
    /// Group operations by type (einsum, elem_unary, elem_binary, reduce)
    pub cluster_by_operation: bool,
    /// Use horizontal (left-to-right) layout instead of vertical
    pub horizontal_layout: bool,
    /// Include tensor shapes in labels (if available)
    pub show_shapes: bool,
    /// Highlight specific tensors (by name or index)
    pub highlight_tensors: Vec<String>,
    /// Highlight specific operations (by index)
    pub highlight_nodes: Vec<usize>,
}

/// Export to DOT format writing to a generic writer.
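///
/// # Example
///
/// A minimal sketch writing into a `String` buffer (any `std::fmt::Write`
/// implementor works the same way). The crate-root import path is an
/// assumption mirroring `export_to_dot`, so the snippet is not compiled as a
/// doctest:
/// ```ignore
/// use tensorlogic_ir::{EinsumGraph, EinsumNode, export_to_dot_writer};
///
/// let mut graph = EinsumGraph::new();
/// let t0 = graph.add_tensor("input".to_string());
/// let t1 = graph.add_tensor("output".to_string());
/// graph.add_node(EinsumNode::elem_unary("relu", t0, t1)).unwrap();
///
/// let mut buffer = String::new();
/// export_to_dot_writer(&graph, &mut buffer).unwrap();
/// assert!(buffer.contains("digraph"));
/// ```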
pub fn export_to_dot_writer<W: FmtWrite>(graph: &EinsumGraph, writer: &mut W) -> std::fmt::Result {
    let options = DotExportOptions::default();
    export_to_dot_writer_with_options(graph, writer, &options)
}

/// Export to DOT format with options, writing to a generic writer.
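///
/// # Output
///
/// For a small graph with a single `relu` node the emitted text looks roughly
/// like the sketch below (abridged; the exact fill colors depend on whether
/// each tensor is registered as a graph input or output, and on highlights):
///
/// ```text
/// digraph EinsumGraph {
///   graph [fontname="Helvetica", fontsize=10];
///   ...
///   tensor_0 [label="input", shape=box, style=filled, fillcolor=lightyellow];
///   tensor_1 [label="output", shape=box, style=filled, fillcolor=lightyellow];
///   op_0 [label="relu(·)", shape=ellipse, style=filled, fillcolor=lightgreen];
///   tensor_0 -> op_0;
///   op_0 -> tensor_1;
/// }
/// ```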
pub fn export_to_dot_writer_with_options<W: FmtWrite>(
    graph: &EinsumGraph,
    writer: &mut W,
    options: &DotExportOptions,
) -> std::fmt::Result {
    writeln!(writer, "digraph EinsumGraph {{")?;

    // Graph attributes
    writeln!(writer, "  // Graph styling")?;
    writeln!(writer, "  graph [fontname=\"Helvetica\", fontsize=10];")?;
    writeln!(writer, "  node [fontname=\"Helvetica\", fontsize=10];")?;
    writeln!(writer, "  edge [fontname=\"Helvetica\", fontsize=9];")?;

    if options.horizontal_layout {
        writeln!(writer, "  rankdir=LR;")?;
    }

    writeln!(writer)?;

    // Group operations by type if requested
    let mut op_clusters: HashMap<String, Vec<usize>> = HashMap::new();
    if options.cluster_by_operation {
        for (idx, node) in graph.nodes.iter().enumerate() {
            let cluster_name = match &node.op {
                OpType::Einsum { .. } => "einsum",
                OpType::ElemUnary { .. } => "elem_unary",
                OpType::ElemBinary { .. } => "elem_binary",
                OpType::Reduce { .. } => "reduce",
            };
            op_clusters
                .entry(cluster_name.to_string())
                .or_default()
                .push(idx);
        }
    }

    // Collect input and output tensors
    let mut used_tensors = HashSet::new();
    for node in &graph.nodes {
        for &input in &node.inputs {
            used_tensors.insert(input);
        }
        for &output in &node.outputs {
            used_tensors.insert(output);
        }
    }

    // Write tensor nodes
    writeln!(writer, "  // Tensor nodes")?;
    for (idx, tensor_name) in graph.tensors.iter().enumerate() {
        if !used_tensors.contains(&idx) && !graph.inputs.contains(&idx) {
            continue; // Skip unused tensors
        }

        let label = if options.show_tensor_ids {
            format!("{} [{}]", escape_label(tensor_name), idx)
        } else {
            escape_label(tensor_name)
        };

        let is_input = graph.inputs.contains(&idx);
        let is_output = graph.outputs.contains(&idx);
        let is_highlighted = options.highlight_tensors.contains(tensor_name)
            || options
                .highlight_tensors
                .contains(&format!("tensor_{}", idx));

        let color = if is_highlighted {
            "red"
        } else if is_input && is_output {
            "purple"
        } else if is_input {
            "lightblue"
        } else if is_output {
            "lightgreen"
        } else {
            "lightyellow"
        };

        writeln!(
            writer,
            "  tensor_{} [label=\"{}\", shape=box, style=filled, fillcolor={}];",
            idx, label, color
        )?;
    }

    writeln!(writer)?;

    // Write operation nodes, possibly clustered
    if options.cluster_by_operation && !op_clusters.is_empty() {
        for (cluster_name, node_indices) in &op_clusters {
            writeln!(
                writer,
                "  subgraph cluster_{} {{",
                cluster_name.replace('.', "_")
            )?;
            writeln!(writer, "    label=\"{}\";", cluster_name)?;
            writeln!(writer, "    style=dashed;")?;

            for &node_idx in node_indices {
                write_operation_node(writer, &graph.nodes[node_idx], node_idx, options)?;
            }

            writeln!(writer, "  }}")?;
            writeln!(writer)?;
        }
    } else {
        writeln!(writer, "  // Operation nodes")?;
        for (idx, node) in graph.nodes.iter().enumerate() {
            write_operation_node(writer, node, idx, options)?;
        }
        writeln!(writer)?;
    }

    // Write edges
    writeln!(writer, "  // Data flow edges")?;
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        // Input edges
        for &input_tensor in &node.inputs {
            writeln!(writer, "  tensor_{} -> op_{};", input_tensor, node_idx)?;
        }

        // Output edges
        for &output_tensor in &node.outputs {
            writeln!(writer, "  op_{} -> tensor_{};", node_idx, output_tensor)?;
        }
    }

    writeln!(writer, "}}")?;

    Ok(())
}

/// Write a single operation node to the DOT output.
fn write_operation_node<W: FmtWrite>(
    writer: &mut W,
    node: &EinsumNode,
    idx: usize,
    options: &DotExportOptions,
) -> std::fmt::Result {
    let (op_type, op_label) = match &node.op {
        OpType::Einsum { spec } => ("einsum", format!("einsum\\n{}", escape_label(spec))),
        OpType::ElemUnary { op } => ("elem_unary", format!("{}(·)", escape_label(op))),
        OpType::ElemBinary { op } => ("elem_binary", format!("{}(·,·)", escape_label(op))),
        OpType::Reduce { op, axes } => ("reduce", format!("{}(axes={:?})", escape_label(op), axes)),
    };

    let label = if options.show_node_ids {
        format!("{}\\n[op_{}]", op_label, idx)
    } else {
        op_label
    };

    let is_highlighted = options.highlight_nodes.contains(&idx);
    let color = if is_highlighted {
        "orange"
    } else {
        match op_type {
            "einsum" => "lightcyan",
            "elem_unary" => "lightgreen",
            "elem_binary" => "lightyellow",
            "reduce" => "lightpink",
            _ => "white",
        }
    };

    writeln!(
        writer,
        "  op_{} [label=\"{}\", shape=ellipse, style=filled, fillcolor={}];",
        idx, label, color
    )?;

    Ok(())
}

/// Escape special characters in DOT labels.
fn escape_label(s: &str) -> String {
    s.replace('\\', "\\\\")
        .replace('"', "\\\"")
        .replace('\n', "\\n")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EinsumGraph, EinsumNode};

    #[test]
    fn test_export_empty_graph() {
        let graph = EinsumGraph::new();
        let dot = export_to_dot(&graph);
        assert!(dot.contains("digraph EinsumGraph"));
    }

    #[test]
    fn test_export_simple_operation() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let node = EinsumNode::elem_unary("relu", t0, t1);
        graph.add_node(node).unwrap();

        let dot = export_to_dot(&graph);
        assert!(dot.contains("relu"));
        assert!(dot.contains("tensor_0"));
        assert!(dot.contains("tensor_1"));
        assert!(dot.contains("op_0"));
    }

    #[test]
    fn test_export_with_einsum() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("A".to_string());
        let t1 = graph.add_tensor("B".to_string());
        let t2 = graph.add_tensor("C".to_string());

        let node = EinsumNode::einsum("ij,jk->ik", vec![t0, t1], vec![t2]);
        graph.add_node(node).unwrap();

        let dot = export_to_dot(&graph);
        assert!(dot.contains("einsum"));
        assert!(dot.contains("ij,jk->ik"));
    }

    #[test]
    fn test_export_with_options() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x".to_string());
        let t1 = graph.add_tensor("y".to_string());

        let node = EinsumNode::elem_unary("sigmoid", t0, t1);
        graph.add_node(node).unwrap();

        let options = DotExportOptions {
            show_tensor_ids: true,
            show_node_ids: true,
            horizontal_layout: true,
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("[0]")); // Tensor ID
        assert!(dot.contains("[op_0]")); // Node ID
    }

    #[test]
    fn test_export_with_clustering() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());
        let t2 = graph.add_tensor("c".to_string());
        let t3 = graph.add_tensor("d".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("sigmoid", t1, t2))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("add", t2, t0, t3))
            .unwrap();

        let options = DotExportOptions {
            cluster_by_operation: true,
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("subgraph cluster_elem_unary"));
        assert!(dot.contains("subgraph cluster_elem_binary"));
    }

    #[test]
    fn test_export_with_highlights() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("hidden".to_string());
        let t2 = graph.add_tensor("output".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("softmax", t1, t2))
            .unwrap();

        let options = DotExportOptions {
            highlight_tensors: vec!["output".to_string()],
            highlight_nodes: vec![0],
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("red")); // Highlighted tensor
        assert!(dot.contains("orange")); // Highlighted operation
    }

    #[test]
    fn test_label_escaping() {
        assert_eq!(escape_label("hello\"world"), "hello\\\"world");
        assert_eq!(escape_label("line1\nline2"), "line1\\nline2");
        assert_eq!(escape_label("path\\to\\file"), "path\\\\to\\\\file");
    }

    #[test]
    fn test_complex_graph_export() {
        let mut graph = EinsumGraph::new();

        // Build a more complex graph: (a + b) * c
        let a = graph.add_tensor("a".to_string());
        let b = graph.add_tensor("b".to_string());
        let c = graph.add_tensor("c".to_string());
        let sum = graph.add_tensor("sum".to_string());
        let result = graph.add_tensor("result".to_string());

        graph.inputs = vec![a, b, c];
        graph.outputs = vec![result];

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, sum))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("multiply", sum, c, result))
            .unwrap();

        let dot = export_to_dot(&graph);

        // Verify structure
        assert!(dot.contains("tensor_0")); // a
        assert!(dot.contains("tensor_1")); // b
        assert!(dot.contains("tensor_2")); // c
        assert!(dot.contains("tensor_3")); // sum
        assert!(dot.contains("tensor_4")); // result
        assert!(dot.contains("op_0")); // add
        assert!(dot.contains("op_1")); // multiply

        // Verify edges
        assert!(dot.contains("tensor_0 -> op_0")); // a -> add
        assert!(dot.contains("tensor_1 -> op_0")); // b -> add
        assert!(dot.contains("op_0 -> tensor_3")); // add -> sum
        assert!(dot.contains("tensor_3 -> op_1")); // sum -> multiply
        assert!(dot.contains("tensor_2 -> op_1")); // c -> multiply
        assert!(dot.contains("op_1 -> tensor_4")); // multiply -> result
    }
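
    // Additional sanity check (a small sketch, not exhaustive): `export_to_dot`
    // is a thin wrapper that writes into a `String` via `export_to_dot_writer`,
    // so both entry points should produce identical output for the same graph.
    #[test]
    fn test_writer_matches_string_export() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());
        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();

        let mut via_writer = String::new();
        export_to_dot_writer(&graph, &mut via_writer).unwrap();
        assert_eq!(via_writer, export_to_dot(&graph));
    }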
}