1use crate::graph::{EinsumGraph, EinsumNode, OpType};
22use std::collections::{HashMap, HashSet};
23use std::fmt::Write as FmtWrite;
24
25pub fn export_to_dot(graph: &EinsumGraph) -> String {
57 let mut output = String::new();
58 export_to_dot_writer(graph, &mut output).expect("String write should not fail");
59 output
60}
61
62pub fn export_to_dot_with_options(graph: &EinsumGraph, options: &DotExportOptions) -> String {
94 let mut output = String::new();
95 export_to_dot_writer_with_options(graph, &mut output, options)
96 .expect("String write should not fail");
97 output
98}
99
/// Options controlling how [`export_to_dot_with_options`] renders a graph.
///
/// Every flag defaults to `false` and both highlight lists default to empty;
/// the default value is what [`export_to_dot`] uses.
#[derive(Debug, Clone, Default)]
pub struct DotExportOptions {
    /// Append ` [<index>]` to each tensor label.
    pub show_tensor_ids: bool,
    /// Append `[op_<index>]` on a new label line to each operation node.
    pub show_node_ids: bool,
    /// NOTE(review): not read by any exporter function in this file, so it
    /// currently has no effect on the output.
    pub show_metadata: bool,
    /// Group operation nodes into dashed `cluster_<kind>` subgraphs, one per
    /// operation kind (einsum, elem_unary, elem_binary, reduce).
    pub cluster_by_operation: bool,
    /// Emit `rankdir=LR` so the graph is laid out left-to-right.
    pub horizontal_layout: bool,
    /// NOTE(review): not read by any exporter function in this file, so it
    /// currently has no effect on the output.
    pub show_shapes: bool,
    /// Tensors to fill red, matched either by tensor name or by DOT node id
    /// (`tensor_<index>`).
    pub highlight_tensors: Vec<String>,
    /// Indices into the graph's node list whose operation vertices are filled
    /// orange.
    pub highlight_nodes: Vec<usize>,
}
120
121pub fn export_to_dot_writer<W: FmtWrite>(graph: &EinsumGraph, writer: &mut W) -> std::fmt::Result {
123 let options = DotExportOptions::default();
124 export_to_dot_writer_with_options(graph, writer, &options)
125}
126
127pub fn export_to_dot_writer_with_options<W: FmtWrite>(
129 graph: &EinsumGraph,
130 writer: &mut W,
131 options: &DotExportOptions,
132) -> std::fmt::Result {
133 writeln!(writer, "digraph EinsumGraph {{")?;
134
135 writeln!(writer, " // Graph styling")?;
137 writeln!(writer, " graph [fontname=\"Helvetica\", fontsize=10];")?;
138 writeln!(writer, " node [fontname=\"Helvetica\", fontsize=10];")?;
139 writeln!(writer, " edge [fontname=\"Helvetica\", fontsize=9];")?;
140
141 if options.horizontal_layout {
142 writeln!(writer, " rankdir=LR;")?;
143 }
144
145 writeln!(writer)?;
146
147 let mut op_clusters: HashMap<String, Vec<usize>> = HashMap::new();
149 if options.cluster_by_operation {
150 for (idx, node) in graph.nodes.iter().enumerate() {
151 let cluster_name = match &node.op {
152 OpType::Einsum { .. } => "einsum",
153 OpType::ElemUnary { .. } => "elem_unary",
154 OpType::ElemBinary { .. } => "elem_binary",
155 OpType::Reduce { .. } => "reduce",
156 };
157 op_clusters
158 .entry(cluster_name.to_string())
159 .or_default()
160 .push(idx);
161 }
162 }
163
164 let mut used_tensors = HashSet::new();
166 for node in &graph.nodes {
167 for &input in &node.inputs {
168 used_tensors.insert(input);
169 }
170 for &output in &node.outputs {
171 used_tensors.insert(output);
172 }
173 }
174
175 writeln!(writer, " // Tensor nodes")?;
177 for (idx, tensor_name) in graph.tensors.iter().enumerate() {
178 if !used_tensors.contains(&idx) && !graph.inputs.contains(&idx) {
179 continue; }
181
182 let label = if options.show_tensor_ids {
183 format!("{} [{}]", escape_label(tensor_name), idx)
184 } else {
185 escape_label(tensor_name)
186 };
187
188 let is_input = graph.inputs.contains(&idx);
189 let is_output = graph.outputs.contains(&idx);
190 let is_highlighted = options.highlight_tensors.contains(tensor_name)
191 || options
192 .highlight_tensors
193 .contains(&format!("tensor_{}", idx));
194
195 let color = if is_highlighted {
196 "red"
197 } else if is_input && is_output {
198 "purple"
199 } else if is_input {
200 "lightblue"
201 } else if is_output {
202 "lightgreen"
203 } else {
204 "lightyellow"
205 };
206
207 writeln!(
208 writer,
209 " tensor_{} [label=\"{}\", shape=box, style=filled, fillcolor={}];",
210 idx, label, color
211 )?;
212 }
213
214 writeln!(writer)?;
215
216 if options.cluster_by_operation && !op_clusters.is_empty() {
218 for (cluster_name, node_indices) in &op_clusters {
219 writeln!(
220 writer,
221 " subgraph cluster_{} {{",
222 cluster_name.replace('.', "_")
223 )?;
224 writeln!(writer, " label=\"{}\";", cluster_name)?;
225 writeln!(writer, " style=dashed;")?;
226
227 for &node_idx in node_indices {
228 write_operation_node(writer, &graph.nodes[node_idx], node_idx, options)?;
229 }
230
231 writeln!(writer, " }}")?;
232 writeln!(writer)?;
233 }
234 } else {
235 writeln!(writer, " // Operation nodes")?;
236 for (idx, node) in graph.nodes.iter().enumerate() {
237 write_operation_node(writer, node, idx, options)?;
238 }
239 writeln!(writer)?;
240 }
241
242 writeln!(writer, " // Data flow edges")?;
244 for (node_idx, node) in graph.nodes.iter().enumerate() {
245 for &input_tensor in &node.inputs {
247 writeln!(writer, " tensor_{} -> op_{};", input_tensor, node_idx)?;
248 }
249
250 for &output_tensor in &node.outputs {
252 writeln!(writer, " op_{} -> tensor_{};", node_idx, output_tensor)?;
253 }
254 }
255
256 writeln!(writer, "}}")?;
257
258 Ok(())
259}
260
261fn write_operation_node<W: FmtWrite>(
263 writer: &mut W,
264 node: &EinsumNode,
265 idx: usize,
266 options: &DotExportOptions,
267) -> std::fmt::Result {
268 let (op_type, op_label) = match &node.op {
269 OpType::Einsum { spec } => ("einsum", format!("einsum\\n{}", escape_label(spec))),
270 OpType::ElemUnary { op } => ("elem_unary", format!("{}(·)", escape_label(op))),
271 OpType::ElemBinary { op } => ("elem_binary", format!("{}(·,·)", escape_label(op))),
272 OpType::Reduce { op, axes } => ("reduce", format!("{}(axes={:?})", escape_label(op), axes)),
273 };
274
275 let label = if options.show_node_ids {
276 format!("{}\\n[op_{}]", op_label, idx)
277 } else {
278 op_label
279 };
280
281 let is_highlighted = options.highlight_nodes.contains(&idx);
282 let color = if is_highlighted {
283 "orange"
284 } else {
285 match op_type {
286 "einsum" => "lightcyan",
287 "elem_unary" => "lightgreen",
288 "elem_binary" => "lightyellow",
289 "reduce" => "lightpink",
290 _ => "white",
291 }
292 };
293
294 writeln!(
295 writer,
296 " op_{} [label=\"{}\", shape=ellipse, style=filled, fillcolor={}];",
297 idx, label, color
298 )?;
299
300 Ok(())
301}
302
/// Escapes `s` for use inside a double-quoted DOT label.
///
/// The following characters are rewritten:
/// - a backslash becomes `\\`,
/// - a double quote becomes `\"`,
/// - a line break becomes the DOT line-break escape `\n`.
///
/// `\n`, a `\r\n` pair and a lone `\r` all count as a single line break. This
/// stops Windows-style text from leaking raw carriage returns into the output.
///
/// All other characters, including non-ASCII ones, are copied unchanged. The
/// work is done in a single pass over the input.
fn escape_label(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => {
                // Fold CRLF into one break instead of emitting two.
                if chars.peek() == Some(&'\n') {
                    chars.next();
                }
                escaped.push_str("\\n");
            }
            other => escaped.push(other),
        }
    }
    escaped
}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EinsumGraph, EinsumNode};

    #[test]
    fn test_export_empty_graph() {
        let graph = EinsumGraph::new();
        let dot = export_to_dot(&graph);
        assert!(dot.starts_with("digraph EinsumGraph {"));
        assert!(dot.trim_end().ends_with('}'));
        // No operations means no data-flow edges.
        assert!(!dot.contains("->"));
    }

    #[test]
    fn test_export_simple_operation() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let node = EinsumNode::elem_unary("relu", t0, t1);
        graph.add_node(node).unwrap();

        let dot = export_to_dot(&graph);
        assert!(dot.contains("relu"));
        assert!(dot.contains("tensor_0"));
        assert!(dot.contains("tensor_1"));
        assert!(dot.contains("op_0"));
        assert!(dot.contains("tensor_0 -> op_0;"));
        assert!(dot.contains("op_0 -> tensor_1;"));
    }

    #[test]
    fn test_export_with_einsum() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("A".to_string());
        let t1 = graph.add_tensor("B".to_string());
        let t2 = graph.add_tensor("C".to_string());

        let node = EinsumNode::einsum("ij,jk->ik", vec![t0, t1], vec![t2]);
        graph.add_node(node).unwrap();

        let dot = export_to_dot(&graph);
        assert!(dot.contains("einsum"));
        assert!(dot.contains("ij,jk->ik"));
        assert!(dot.contains("fillcolor=lightcyan"));
        assert!(dot.contains("tensor_0 -> op_0;"));
        assert!(dot.contains("tensor_1 -> op_0;"));
        assert!(dot.contains("op_0 -> tensor_2;"));
    }

    #[test]
    fn test_export_with_options() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x".to_string());
        let t1 = graph.add_tensor("y".to_string());

        let node = EinsumNode::elem_unary("sigmoid", t0, t1);
        graph.add_node(node).unwrap();

        let options = DotExportOptions {
            show_tensor_ids: true,
            show_node_ids: true,
            horizontal_layout: true,
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("rankdir=LR"));
        assert!(dot.contains("label=\"x [0]\""));
        assert!(dot.contains("label=\"y [1]\""));
        assert!(dot.contains("[op_0]"));
    }

    #[test]
    fn test_export_with_clustering() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());
        let t2 = graph.add_tensor("c".to_string());
        let t3 = graph.add_tensor("d".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("sigmoid", t1, t2))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("add", t2, t0, t3))
            .unwrap();

        let options = DotExportOptions {
            cluster_by_operation: true,
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("subgraph cluster_elem_unary"));
        assert!(dot.contains("subgraph cluster_elem_binary"));
        assert!(dot.contains("label=\"elem_unary\";"));
        assert!(dot.contains("style=dashed;"));
        // Kinds with no nodes get no cluster.
        assert!(!dot.contains("cluster_einsum"));
        assert!(!dot.contains("cluster_reduce"));
    }

    #[test]
    fn test_export_with_highlights() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("hidden".to_string());
        let t2 = graph.add_tensor("output".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("softmax", t1, t2))
            .unwrap();

        let options = DotExportOptions {
            highlight_tensors: vec!["output".to_string()],
            highlight_nodes: vec![0],
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains(
            "tensor_2 [label=\"output\", shape=box, style=filled, fillcolor=red]"
        ));
        assert!(dot.contains(
            "op_0 [label=\"relu(·)\", shape=ellipse, style=filled, fillcolor=orange]"
        ));
        // Non-highlighted operations keep their per-kind colour.
        assert!(dot.contains(
            "op_1 [label=\"softmax(·)\", shape=ellipse, style=filled, fillcolor=lightgreen]"
        ));
    }

    #[test]
    fn test_highlight_tensor_by_dot_id() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x".to_string());
        let t1 = graph.add_tensor("y".to_string());
        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();

        let options = DotExportOptions {
            highlight_tensors: vec!["tensor_1".to_string()],
            ..Default::default()
        };

        let dot = export_to_dot_with_options(&graph, &options);
        assert!(dot.contains("tensor_1 [label=\"y\", shape=box, style=filled, fillcolor=red]"));
        assert!(!dot.contains("tensor_0 [label=\"x\", shape=box, style=filled, fillcolor=red]"));
    }

    #[test]
    fn test_unused_tensors_are_skipped_unless_inputs() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x".to_string());
        let t1 = graph.add_tensor("y".to_string());
        let _orphan = graph.add_tensor("orphan".to_string());
        let dangling_input = graph.add_tensor("dangling_input".to_string());
        graph.inputs = vec![dangling_input];

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();

        let dot = export_to_dot(&graph);
        assert!(!dot.contains("orphan"));
        assert!(!dot.contains("tensor_2"));
        assert!(dot.contains(
            "tensor_3 [label=\"dangling_input\", shape=box, style=filled, fillcolor=lightblue]"
        ));
    }

    #[test]
    fn test_label_escaping() {
        assert_eq!(escape_label("hello\"world"), "hello\\\"world");
        assert_eq!(escape_label("line1\nline2"), "line1\\nline2");
        assert_eq!(escape_label("path\\to\\file"), "path\\\\to\\\\file");
        assert_eq!(escape_label(""), "");
        assert_eq!(escape_label("plain"), "plain");
    }

    #[test]
    fn test_complex_graph_export() {
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("a".to_string());
        let b = graph.add_tensor("b".to_string());
        let c = graph.add_tensor("c".to_string());
        let sum = graph.add_tensor("sum".to_string());
        let result = graph.add_tensor("result".to_string());

        graph.inputs = vec![a, b, c];
        graph.outputs = vec![result];

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, sum))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("multiply", sum, c, result))
            .unwrap();

        let dot = export_to_dot(&graph);

        // Every tensor and operation is present.
        assert!(dot.contains("tensor_0"));
        assert!(dot.contains("tensor_1"));
        assert!(dot.contains("tensor_2"));
        assert!(dot.contains("tensor_3"));
        assert!(dot.contains("tensor_4"));
        assert!(dot.contains("op_0"));
        assert!(dot.contains("op_1"));

        // Data-flow edges.
        assert!(dot.contains("tensor_0 -> op_0"));
        assert!(dot.contains("tensor_1 -> op_0"));
        assert!(dot.contains("op_0 -> tensor_3"));
        assert!(dot.contains("tensor_3 -> op_1"));
        assert!(dot.contains("tensor_2 -> op_1"));
        assert!(dot.contains("op_1 -> tensor_4"));

        // Role-based colouring: inputs, intermediate, output.
        assert!(dot.contains("tensor_0 [label=\"a\", shape=box, style=filled, fillcolor=lightblue]"));
        assert!(dot.contains(
            "tensor_3 [label=\"sum\", shape=box, style=filled, fillcolor=lightyellow]"
        ));
        assert!(dot.contains(
            "tensor_4 [label=\"result\", shape=box, style=filled, fillcolor=lightgreen]"
        ));
    }
}