use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
11
#[derive(Debug)]
/// Builds a [`ComputationGraph`] incrementally (nodes, edges, input/output
/// markers) and renders it as Graphviz DOT, JSON, or a plain-text summary.
pub struct GraphVisualizer {
    /// The graph under construction; filled in by `add_node`, `add_edge`,
    /// `mark_input` and `mark_output`.
    graph: ComputationGraph,
    /// Rendering options read by the DOT exporter and node colouring.
    config: GraphVisualizerConfig,
}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct GraphVisualizerConfig {
24 pub show_shapes: bool,
26 pub show_dtypes: bool,
28 pub show_attributes: bool,
30 pub layout_direction: LayoutDirection,
32 pub max_depth: i32,
34 pub color_scheme: GraphColorScheme,
36}
37
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Graph layout direction; written to the DOT output as the `rankdir`
/// attribute by `GraphVisualizer::export_to_dot`.
pub enum LayoutDirection {
    /// Emitted as `rankdir=TB`.
    TopToBottom,
    /// Emitted as `rankdir=LR`.
    LeftToRight,
    /// Emitted as `rankdir=BT`.
    BottomToTop,
    /// Emitted as `rankdir=RL`.
    RightToLeft,
}
50
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Strategy used by `GraphVisualizer::get_node_color` to pick a node's fill colour.
pub enum GraphColorScheme {
    /// Every node is `lightblue`.
    Default,
    /// Colour chosen from the node's `op_type` (e.g. `Linear`/`Dense` are
    /// `lightblue`, `Conv1d`/`Conv2d` are `lightgreen`); unknown types are `white`.
    ByLayerType,
    /// Colour chosen from the node's depth, used as a stand-in for cost:
    /// depth > 10 is `darkred`, depth > 5 is `orange`, otherwise `lightgreen`.
    ByCost,
    /// Marked inputs are `lightgreen`, marked outputs `lightcoral`, all other
    /// nodes `lightblue` (input takes precedence if a node is both).
    ByDataFlow,
}
63
64impl Default for GraphVisualizerConfig {
65 fn default() -> Self {
66 Self {
67 show_shapes: true,
68 show_dtypes: true,
69 show_attributes: false,
70 layout_direction: LayoutDirection::TopToBottom,
71 max_depth: -1,
72 color_scheme: GraphColorScheme::ByLayerType,
73 }
74 }
75}
76
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A named directed graph of operations, serializable to JSON.
pub struct ComputationGraph {
    /// Graph name, as passed to `GraphVisualizer::new` / `with_config`.
    pub name: String,
    /// Nodes in insertion order.
    pub nodes: Vec<GraphNode>,
    /// Directed edges in insertion order; endpoints refer to node ids.
    pub edges: Vec<GraphEdge>,
    /// Ids marked via `mark_input` (deduplicated); depth computation starts here.
    pub inputs: Vec<String>,
    /// Ids marked via `mark_output` (deduplicated).
    pub outputs: Vec<String>,
}
91
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One operation in a [`ComputationGraph`].
pub struct GraphNode {
    /// Identifier that edges and input/output markers refer to.
    pub id: String,
    /// Display text used as the DOT node label.
    pub label: String,
    /// Operation name; used for `ByLayerType` colouring and per-type statistics.
    pub op_type: String,
    /// Optional tensor shape, shown in DOT when `show_shapes` is enabled.
    pub shape: Option<Vec<i64>>,
    /// Optional dtype name, shown in DOT when `show_dtypes` is enabled.
    pub dtype: Option<String>,
    /// Free-form key/value metadata attached to the node.
    pub attributes: HashMap<String, String>,
    /// Distance from the graph inputs. Set to 0 by `add_node` and recomputed
    /// when `export_to_dot` runs.
    pub depth: usize,
}
110
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A directed connection between two nodes of a [`ComputationGraph`].
pub struct GraphEdge {
    /// Id of the source node.
    pub from: String,
    /// Id of the destination node.
    pub to: String,
    /// Optional DOT edge label; takes precedence over `shape` when present.
    pub label: Option<String>,
    /// Optional tensor shape flowing along the edge; used as the DOT label
    /// when there is no explicit `label` and `show_shapes` is enabled.
    pub shape: Option<Vec<i64>>,
}
123
124impl GraphVisualizer {
125 pub fn new(graph_name: &str) -> Self {
139 let graph = ComputationGraph {
140 name: graph_name.to_string(),
141 nodes: Vec::new(),
142 edges: Vec::new(),
143 inputs: Vec::new(),
144 outputs: Vec::new(),
145 };
146
147 Self {
148 graph,
149 config: GraphVisualizerConfig::default(),
150 }
151 }
152
153 pub fn with_config(graph_name: &str, config: GraphVisualizerConfig) -> Self {
155 let graph = ComputationGraph {
156 name: graph_name.to_string(),
157 nodes: Vec::new(),
158 edges: Vec::new(),
159 inputs: Vec::new(),
160 outputs: Vec::new(),
161 };
162
163 Self { graph, config }
164 }
165
166 pub fn add_node(
184 &mut self,
185 id: &str,
186 label: &str,
187 op_type: &str,
188 shape: Option<Vec<i64>>,
189 dtype: Option<String>,
190 attributes: HashMap<String, String>,
191 ) {
192 let node = GraphNode {
193 id: id.to_string(),
194 label: label.to_string(),
195 op_type: op_type.to_string(),
196 shape,
197 dtype,
198 attributes,
199 depth: 0, };
201
202 self.graph.nodes.push(node);
203 }
204
205 pub fn add_edge(
207 &mut self,
208 from: &str,
209 to: &str,
210 label: Option<String>,
211 shape: Option<Vec<i64>>,
212 ) {
213 let edge = GraphEdge {
214 from: from.to_string(),
215 to: to.to_string(),
216 label,
217 shape,
218 };
219
220 self.graph.edges.push(edge);
221 }
222
223 pub fn mark_input(&mut self, node_id: &str) {
225 if !self.graph.inputs.contains(&node_id.to_string()) {
226 self.graph.inputs.push(node_id.to_string());
227 }
228 }
229
230 pub fn mark_output(&mut self, node_id: &str) {
232 if !self.graph.outputs.contains(&node_id.to_string()) {
233 self.graph.outputs.push(node_id.to_string());
234 }
235 }
236
237 fn compute_depths(&mut self) {
239 let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
241 for edge in &self.graph.edges {
242 adjacency.entry(edge.from.clone()).or_default().push(edge.to.clone());
243 }
244
245 let mut depths: HashMap<String, usize> = HashMap::new();
247 let mut queue: Vec<(String, usize)> = Vec::new();
248
249 for input in &self.graph.inputs {
250 queue.push((input.clone(), 0));
251 depths.insert(input.clone(), 0);
252 }
253
254 while let Some((node_id, depth)) = queue.pop() {
255 if let Some(neighbors) = adjacency.get(&node_id) {
256 for neighbor in neighbors {
257 let new_depth = depth + 1;
258 if !depths.contains_key(neighbor) || depths[neighbor] < new_depth {
259 depths.insert(neighbor.clone(), new_depth);
260 queue.push((neighbor.clone(), new_depth));
261 }
262 }
263 }
264 }
265
266 for node in &mut self.graph.nodes {
268 node.depth = *depths.get(&node.id).unwrap_or(&0);
269 }
270 }
271
272 pub fn export_to_dot<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
282 self.compute_depths();
283
284 let mut dot = String::from("digraph {\n");
285
286 let direction = match self.config.layout_direction {
288 LayoutDirection::TopToBottom => "TB",
289 LayoutDirection::LeftToRight => "LR",
290 LayoutDirection::BottomToTop => "BT",
291 LayoutDirection::RightToLeft => "RL",
292 };
293 dot.push_str(&format!(" rankdir={};\n", direction));
294 dot.push_str(" node [shape=box, style=rounded];\n\n");
295
296 for node in &self.graph.nodes {
298 if self.config.max_depth >= 0 && node.depth > self.config.max_depth as usize {
299 continue;
300 }
301
302 let color = self.get_node_color(node);
303 let mut label = node.label.to_string();
304
305 if self.config.show_shapes {
306 if let Some(ref shape) = node.shape {
307 label.push_str(&format!("\\nshape: {:?}", shape));
308 }
309 }
310
311 if self.config.show_dtypes {
312 if let Some(ref dtype) = node.dtype {
313 label.push_str(&format!("\\ndtype: {}", dtype));
314 }
315 }
316
317 dot.push_str(&format!(
318 " \"{}\" [label=\"{}\", fillcolor=\"{}\", style=\"filled,rounded\"];\n",
319 node.id, label, color
320 ));
321 }
322
323 dot.push('\n');
324
325 for edge in &self.graph.edges {
327 let mut edge_label = String::new();
328
329 if let Some(ref label) = edge.label {
330 edge_label = label.clone();
331 } else if self.config.show_shapes {
332 if let Some(ref shape) = edge.shape {
333 edge_label = format!("{:?}", shape);
334 }
335 }
336
337 if !edge_label.is_empty() {
338 dot.push_str(&format!(
339 " \"{}\" -> \"{}\" [label=\"{}\"];\n",
340 edge.from, edge.to, edge_label
341 ));
342 } else {
343 dot.push_str(&format!(" \"{}\" -> \"{}\";\n", edge.from, edge.to));
344 }
345 }
346
347 dot.push_str("}\n");
348
349 fs::write(path, dot)?;
350 Ok(())
351 }
352
353 fn get_node_color(&self, node: &GraphNode) -> &'static str {
355 match self.config.color_scheme {
356 GraphColorScheme::Default => "lightblue",
357 GraphColorScheme::ByLayerType => match node.op_type.as_str() {
358 "Linear" | "Dense" => "lightblue",
359 "Conv2d" | "Conv1d" => "lightgreen",
360 "BatchNorm" | "LayerNorm" => "lightyellow",
361 "ReLU" | "GELU" | "Softmax" => "lightcoral",
362 "Dropout" => "lightgray",
363 "Attention" | "MultiHeadAttention" => "plum",
364 _ => "white",
365 },
366 GraphColorScheme::ByCost => {
367 if node.depth > 10 {
369 "darkred"
370 } else if node.depth > 5 {
371 "orange"
372 } else {
373 "lightgreen"
374 }
375 },
376 GraphColorScheme::ByDataFlow => {
377 if self.graph.inputs.contains(&node.id) {
378 "lightgreen"
379 } else if self.graph.outputs.contains(&node.id) {
380 "lightcoral"
381 } else {
382 "lightblue"
383 }
384 },
385 }
386 }
387
388 pub fn export_to_json<P: AsRef<Path>>(&self, path: P) -> Result<()> {
390 let json = serde_json::to_string_pretty(&self.graph)?;
391 fs::write(path, json)?;
392 Ok(())
393 }
394
395 pub fn statistics(&self) -> GraphStatistics {
397 let num_nodes = self.graph.nodes.len();
398 let num_edges = self.graph.edges.len();
399
400 let op_type_counts: HashMap<String, usize> =
401 self.graph.nodes.iter().fold(HashMap::new(), |mut acc, node| {
402 *acc.entry(node.op_type.clone()).or_insert(0) += 1;
403 acc
404 });
405
406 let max_depth = self.graph.nodes.iter().map(|n| n.depth).max().unwrap_or(0);
407
408 GraphStatistics {
409 num_nodes,
410 num_edges,
411 num_inputs: self.graph.inputs.len(),
412 num_outputs: self.graph.outputs.len(),
413 max_depth,
414 op_type_counts,
415 }
416 }
417
418 pub fn summary(&self) -> String {
420 let stats = self.statistics();
421
422 let mut output = String::new();
423 output.push_str(&format!("Computation Graph: {}\n", self.graph.name));
424 output.push_str(&"=".repeat(60));
425 output.push('\n');
426 output.push_str(&format!("Nodes: {}\n", stats.num_nodes));
427 output.push_str(&format!("Edges: {}\n", stats.num_edges));
428 output.push_str(&format!("Inputs: {}\n", stats.num_inputs));
429 output.push_str(&format!("Outputs: {}\n", stats.num_outputs));
430 output.push_str(&format!("Max Depth: {}\n", stats.max_depth));
431
432 output.push_str("\nOperation Types:\n");
433 for (op_type, count) in &stats.op_type_counts {
434 output.push_str(&format!(" {}: {}\n", op_type, count));
435 }
436
437 output
438 }
439}
440
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Aggregate counts describing a [`ComputationGraph`], as returned by
/// `GraphVisualizer::statistics`.
pub struct GraphStatistics {
    /// Number of nodes added.
    pub num_nodes: usize,
    /// Number of edges added.
    pub num_edges: usize,
    /// Number of ids marked as inputs.
    pub num_inputs: usize,
    /// Number of ids marked as outputs.
    pub num_outputs: usize,
    /// Largest depth among the graph's nodes (0 for an empty graph).
    pub max_depth: usize,
    /// Number of nodes per `op_type`.
    pub op_type_counts: HashMap<String, usize>,
}
457
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Returns a path in the system temp directory that is unique to this test
    /// process, so concurrent test runs sharing a temp dir don't race on one file.
    fn temp_output(file_name: &str) -> PathBuf {
        env::temp_dir().join(format!(
            "graph_visualizer_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    #[test]
    fn test_graph_visualizer_creation() {
        let visualizer = GraphVisualizer::new("test_graph");
        assert_eq!(visualizer.graph.name, "test_graph");
        assert_eq!(visualizer.graph.nodes.len(), 0);
    }

    #[test]
    fn test_add_node() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node(
            "node1",
            "Layer 1",
            "Linear",
            Some(vec![10, 20]),
            Some("float32".to_string()),
            HashMap::new(),
        );

        assert_eq!(visualizer.graph.nodes.len(), 1);
        assert_eq!(visualizer.graph.nodes[0].id, "node1");
        assert_eq!(visualizer.graph.nodes[0].op_type, "Linear");
        assert_eq!(visualizer.graph.nodes[0].depth, 0);
    }

    #[test]
    fn test_add_edge() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("node1", "N1", "Linear", None, None, HashMap::new());
        visualizer.add_node("node2", "N2", "ReLU", None, None, HashMap::new());
        visualizer.add_edge("node1", "node2", None, Some(vec![10, 20]));

        assert_eq!(visualizer.graph.edges.len(), 1);
        assert_eq!(visualizer.graph.edges[0].from, "node1");
        assert_eq!(visualizer.graph.edges[0].to, "node2");
    }

    #[test]
    fn test_mark_input_output() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("input", "Input", "Input", None, None, HashMap::new());
        visualizer.add_node("output", "Output", "Output", None, None, HashMap::new());

        visualizer.mark_input("input");
        visualizer.mark_output("output");
        // Re-marking must not create duplicates.
        visualizer.mark_input("input");
        visualizer.mark_output("output");

        assert_eq!(visualizer.graph.inputs.len(), 1);
        assert_eq!(visualizer.graph.outputs.len(), 1);
    }

    #[test]
    fn test_export_to_dot() {
        let output_path = temp_output("test_graph.dot");

        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("input", "Input", "Input", None, None, HashMap::new());
        visualizer.add_node(
            "layer1",
            "Linear",
            "Linear",
            Some(vec![10, 20]),
            None,
            HashMap::new(),
        );
        visualizer.add_edge("input", "layer1", None, None);

        visualizer.mark_input("input");

        visualizer.export_to_dot(&output_path).unwrap();
        // Read and clean up before asserting so a failure doesn't leak the file.
        let dot = fs::read_to_string(&output_path).unwrap();
        let _ = fs::remove_file(&output_path);

        assert!(dot.starts_with("digraph {"));
        assert!(dot.contains("rankdir=TB;"));
        assert!(dot.contains("\"input\" -> \"layer1\""));
        assert!(dot.contains("shape: [10, 20]"));
        assert!(dot.trim_end().ends_with('}'));
    }

    #[test]
    fn test_export_to_json() {
        let output_path = temp_output("test_graph.json");

        let mut visualizer = GraphVisualizer::new("test");
        visualizer.add_node("node1", "N1", "Linear", None, None, HashMap::new());

        visualizer.export_to_json(&output_path).unwrap();
        let json = fs::read_to_string(&output_path).unwrap();
        let _ = fs::remove_file(&output_path);

        let parsed: ComputationGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.nodes.len(), 1);
        assert_eq!(parsed.nodes[0].op_type, "Linear");
    }

    #[test]
    fn test_statistics() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("n1", "N1", "Linear", None, None, HashMap::new());
        visualizer.add_node("n2", "N2", "Linear", None, None, HashMap::new());
        visualizer.add_node("n3", "N3", "ReLU", None, None, HashMap::new());

        visualizer.add_edge("n1", "n2", None, None);
        visualizer.add_edge("n2", "n3", None, None);

        visualizer.mark_input("n1");
        visualizer.mark_output("n3");
        visualizer.compute_depths();

        let stats = visualizer.statistics();

        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 2);
        assert_eq!(stats.num_inputs, 1);
        assert_eq!(stats.num_outputs, 1);
        assert_eq!(stats.max_depth, 2);
        assert_eq!(stats.op_type_counts.get("Linear"), Some(&2));
        assert_eq!(stats.op_type_counts.get("ReLU"), Some(&1));
    }

    #[test]
    fn test_summary() {
        let mut visualizer = GraphVisualizer::new("test_model");

        visualizer.add_node("input", "Input", "Input", None, None, HashMap::new());
        visualizer.add_node("layer1", "Linear", "Linear", None, None, HashMap::new());

        let summary = visualizer.summary();
        assert!(summary.contains("test_model"));
        assert!(summary.contains("Nodes: 2"));
        assert!(summary.contains("Linear: 1"));
    }

    #[test]
    fn test_compute_depths() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("input", "Input", "Input", None, None, HashMap::new());
        visualizer.add_node("layer1", "L1", "Linear", None, None, HashMap::new());
        visualizer.add_node("layer2", "L2", "ReLU", None, None, HashMap::new());

        visualizer.add_edge("input", "layer1", None, None);
        visualizer.add_edge("layer1", "layer2", None, None);

        visualizer.mark_input("input");

        visualizer.compute_depths();

        assert_eq!(visualizer.graph.nodes[0].depth, 0);
        assert_eq!(visualizer.graph.nodes[1].depth, 1);
        assert_eq!(visualizer.graph.nodes[2].depth, 2);
    }

    #[test]
    fn test_compute_depths_uses_longest_path() {
        let mut visualizer = GraphVisualizer::new("test");

        for id in ["input", "a", "b", "out"] {
            visualizer.add_node(id, id, "Linear", None, None, HashMap::new());
        }
        // A skip connection input -> out alongside the chain input -> a -> b -> out.
        visualizer.add_edge("input", "a", None, None);
        visualizer.add_edge("a", "b", None, None);
        visualizer.add_edge("b", "out", None, None);
        visualizer.add_edge("input", "out", None, None);
        visualizer.mark_input("input");

        visualizer.compute_depths();

        assert_eq!(visualizer.graph.nodes[3].depth, 3);
    }

    #[test]
    fn test_compute_depths_unreachable_node_is_zero() {
        let mut visualizer = GraphVisualizer::new("test");

        visualizer.add_node("input", "Input", "Input", None, None, HashMap::new());
        visualizer.add_node("orphan", "Orphan", "Linear", None, None, HashMap::new());
        visualizer.mark_input("input");

        visualizer.compute_depths();

        assert_eq!(visualizer.graph.nodes[1].depth, 0);
    }
}