Skip to main content

trustformers_debug/
graph_visualizer.rs

1//! Computation graph visualization tools
2//!
3//! This module provides tools to visualize the computation graph of neural networks,
4//! including layer connections, tensor shapes, and operation flows.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::Path;
11
12/// Graph visualizer for computation graphs
13#[derive(Debug)]
14pub struct GraphVisualizer {
15    /// Graph definition
16    graph: ComputationGraph,
17    /// Visualization config
18    config: GraphVisualizerConfig,
19}
20
21/// Configuration for graph visualization
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct GraphVisualizerConfig {
24    /// Include tensor shapes in visualization
25    pub show_shapes: bool,
26    /// Include data types in visualization
27    pub show_dtypes: bool,
28    /// Include operation attributes
29    pub show_attributes: bool,
30    /// Layout direction (TB=top-to-bottom, LR=left-to-right)
31    pub layout_direction: LayoutDirection,
32    /// Maximum depth to visualize (-1 for unlimited)
33    pub max_depth: i32,
34    /// Color scheme
35    pub color_scheme: GraphColorScheme,
36}
37
38/// Layout direction for graph visualization
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40pub enum LayoutDirection {
41    /// Top to bottom
42    TopToBottom,
43    /// Left to right
44    LeftToRight,
45    /// Bottom to top
46    BottomToTop,
47    /// Right to left
48    RightToLeft,
49}
50
51/// Color scheme for graph nodes
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
53pub enum GraphColorScheme {
54    /// Default colors
55    Default,
56    /// By layer type
57    ByLayerType,
58    /// By computational cost
59    ByCost,
60    /// By data flow
61    ByDataFlow,
62}
63
64impl Default for GraphVisualizerConfig {
65    fn default() -> Self {
66        Self {
67            show_shapes: true,
68            show_dtypes: true,
69            show_attributes: false,
70            layout_direction: LayoutDirection::TopToBottom,
71            max_depth: -1,
72            color_scheme: GraphColorScheme::ByLayerType,
73        }
74    }
75}
76
77/// Computation graph definition
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ComputationGraph {
80    /// Graph name
81    pub name: String,
82    /// Graph nodes
83    pub nodes: Vec<GraphNode>,
84    /// Graph edges
85    pub edges: Vec<GraphEdge>,
86    /// Input nodes
87    pub inputs: Vec<String>,
88    /// Output nodes
89    pub outputs: Vec<String>,
90}
91
92/// Graph node representing an operation or tensor
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct GraphNode {
95    /// Node ID
96    pub id: String,
97    /// Node label (display name)
98    pub label: String,
99    /// Operation type
100    pub op_type: String,
101    /// Tensor shape (if applicable)
102    pub shape: Option<Vec<i64>>,
103    /// Data type
104    pub dtype: Option<String>,
105    /// Node attributes
106    pub attributes: HashMap<String, String>,
107    /// Node depth in the graph
108    pub depth: usize,
109}
110
111/// Graph edge representing data flow
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct GraphEdge {
114    /// Source node ID
115    pub from: String,
116    /// Target node ID
117    pub to: String,
118    /// Edge label (optional)
119    pub label: Option<String>,
120    /// Tensor shape along this edge
121    pub shape: Option<Vec<i64>>,
122}
123
124impl GraphVisualizer {
125    /// Create a new graph visualizer
126    ///
127    /// # Arguments
128    ///
129    /// * `graph_name` - Name of the computation graph
130    ///
131    /// # Example
132    ///
133    /// ```
134    /// use trustformers_debug::GraphVisualizer;
135    ///
136    /// let visualizer = GraphVisualizer::new("my_model");
137    /// ```
138    pub fn new(graph_name: &str) -> Self {
139        let graph = ComputationGraph {
140            name: graph_name.to_string(),
141            nodes: Vec::new(),
142            edges: Vec::new(),
143            inputs: Vec::new(),
144            outputs: Vec::new(),
145        };
146
147        Self {
148            graph,
149            config: GraphVisualizerConfig::default(),
150        }
151    }
152
153    /// Create a graph visualizer with custom configuration
154    pub fn with_config(graph_name: &str, config: GraphVisualizerConfig) -> Self {
155        let graph = ComputationGraph {
156            name: graph_name.to_string(),
157            nodes: Vec::new(),
158            edges: Vec::new(),
159            inputs: Vec::new(),
160            outputs: Vec::new(),
161        };
162
163        Self { graph, config }
164    }
165
166    /// Add a node to the graph
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// # use trustformers_debug::GraphVisualizer;
172    /// # use std::collections::HashMap;
173    /// # let mut visualizer = GraphVisualizer::new("model");
174    /// visualizer.add_node(
175    ///     "layer1",
176    ///     "Linear Layer 1",
177    ///     "Linear",
178    ///     Some(vec![10, 20]),
179    ///     Some("float32".to_string()),
180    ///     HashMap::new(),
181    /// );
182    /// ```
183    pub fn add_node(
184        &mut self,
185        id: &str,
186        label: &str,
187        op_type: &str,
188        shape: Option<Vec<i64>>,
189        dtype: Option<String>,
190        attributes: HashMap<String, String>,
191    ) {
192        let node = GraphNode {
193            id: id.to_string(),
194            label: label.to_string(),
195            op_type: op_type.to_string(),
196            shape,
197            dtype,
198            attributes,
199            depth: 0, // Will be computed later
200        };
201
202        self.graph.nodes.push(node);
203    }
204
205    /// Add an edge to the graph
206    pub fn add_edge(
207        &mut self,
208        from: &str,
209        to: &str,
210        label: Option<String>,
211        shape: Option<Vec<i64>>,
212    ) {
213        let edge = GraphEdge {
214            from: from.to_string(),
215            to: to.to_string(),
216            label,
217            shape,
218        };
219
220        self.graph.edges.push(edge);
221    }
222
223    /// Mark a node as an input
224    pub fn mark_input(&mut self, node_id: &str) {
225        if !self.graph.inputs.contains(&node_id.to_string()) {
226            self.graph.inputs.push(node_id.to_string());
227        }
228    }
229
230    /// Mark a node as an output
231    pub fn mark_output(&mut self, node_id: &str) {
232        if !self.graph.outputs.contains(&node_id.to_string()) {
233            self.graph.outputs.push(node_id.to_string());
234        }
235    }
236
237    /// Compute node depths in the graph
238    fn compute_depths(&mut self) {
239        // Build adjacency list
240        let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
241        for edge in &self.graph.edges {
242            adjacency.entry(edge.from.clone()).or_default().push(edge.to.clone());
243        }
244
245        // BFS from input nodes
246        let mut depths: HashMap<String, usize> = HashMap::new();
247        let mut queue: Vec<(String, usize)> = Vec::new();
248
249        for input in &self.graph.inputs {
250            queue.push((input.clone(), 0));
251            depths.insert(input.clone(), 0);
252        }
253
254        while let Some((node_id, depth)) = queue.pop() {
255            if let Some(neighbors) = adjacency.get(&node_id) {
256                for neighbor in neighbors {
257                    let new_depth = depth + 1;
258                    if !depths.contains_key(neighbor) || depths[neighbor] < new_depth {
259                        depths.insert(neighbor.clone(), new_depth);
260                        queue.push((neighbor.clone(), new_depth));
261                    }
262                }
263            }
264        }
265
266        // Update node depths
267        for node in &mut self.graph.nodes {
268            node.depth = *depths.get(&node.id).unwrap_or(&0);
269        }
270    }
271
272    /// Export graph to GraphViz DOT format
273    ///
274    /// # Example
275    ///
276    /// ```no_run
277    /// # use trustformers_debug::GraphVisualizer;
278    /// # let mut visualizer = GraphVisualizer::new("model");
279    /// visualizer.export_to_dot("model.dot").unwrap();
280    /// ```
281    pub fn export_to_dot<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
282        self.compute_depths();
283
284        let mut dot = String::from("digraph {\n");
285
286        // Graph attributes
287        let direction = match self.config.layout_direction {
288            LayoutDirection::TopToBottom => "TB",
289            LayoutDirection::LeftToRight => "LR",
290            LayoutDirection::BottomToTop => "BT",
291            LayoutDirection::RightToLeft => "RL",
292        };
293        dot.push_str(&format!("  rankdir={};\n", direction));
294        dot.push_str("  node [shape=box, style=rounded];\n\n");
295
296        // Add nodes
297        for node in &self.graph.nodes {
298            if self.config.max_depth >= 0 && node.depth > self.config.max_depth as usize {
299                continue;
300            }
301
302            let color = self.get_node_color(node);
303            let mut label = node.label.to_string();
304
305            if self.config.show_shapes {
306                if let Some(ref shape) = node.shape {
307                    label.push_str(&format!("\\nshape: {:?}", shape));
308                }
309            }
310
311            if self.config.show_dtypes {
312                if let Some(ref dtype) = node.dtype {
313                    label.push_str(&format!("\\ndtype: {}", dtype));
314                }
315            }
316
317            dot.push_str(&format!(
318                "  \"{}\" [label=\"{}\", fillcolor=\"{}\", style=\"filled,rounded\"];\n",
319                node.id, label, color
320            ));
321        }
322
323        dot.push('\n');
324
325        // Add edges
326        for edge in &self.graph.edges {
327            let mut edge_label = String::new();
328
329            if let Some(ref label) = edge.label {
330                edge_label = label.clone();
331            } else if self.config.show_shapes {
332                if let Some(ref shape) = edge.shape {
333                    edge_label = format!("{:?}", shape);
334                }
335            }
336
337            if !edge_label.is_empty() {
338                dot.push_str(&format!(
339                    "  \"{}\" -> \"{}\" [label=\"{}\"];\n",
340                    edge.from, edge.to, edge_label
341                ));
342            } else {
343                dot.push_str(&format!("  \"{}\" -> \"{}\";\n", edge.from, edge.to));
344            }
345        }
346
347        dot.push_str("}\n");
348
349        fs::write(path, dot)?;
350        Ok(())
351    }
352
353    /// Get color for a node based on color scheme
354    fn get_node_color(&self, node: &GraphNode) -> &'static str {
355        match self.config.color_scheme {
356            GraphColorScheme::Default => "lightblue",
357            GraphColorScheme::ByLayerType => match node.op_type.as_str() {
358                "Linear" | "Dense" => "lightblue",
359                "Conv2d" | "Conv1d" => "lightgreen",
360                "BatchNorm" | "LayerNorm" => "lightyellow",
361                "ReLU" | "GELU" | "Softmax" => "lightcoral",
362                "Dropout" => "lightgray",
363                "Attention" | "MultiHeadAttention" => "plum",
364                _ => "white",
365            },
366            GraphColorScheme::ByCost => {
367                // Simplified: use depth as proxy for computational cost
368                if node.depth > 10 {
369                    "darkred"
370                } else if node.depth > 5 {
371                    "orange"
372                } else {
373                    "lightgreen"
374                }
375            },
376            GraphColorScheme::ByDataFlow => {
377                if self.graph.inputs.contains(&node.id) {
378                    "lightgreen"
379                } else if self.graph.outputs.contains(&node.id) {
380                    "lightcoral"
381                } else {
382                    "lightblue"
383                }
384            },
385        }
386    }
387
388    /// Export graph to JSON format
389    pub fn export_to_json<P: AsRef<Path>>(&self, path: P) -> Result<()> {
390        let json = serde_json::to_string_pretty(&self.graph)?;
391        fs::write(path, json)?;
392        Ok(())
393    }
394
395    /// Get statistics about the graph
396    pub fn statistics(&self) -> GraphStatistics {
397        let num_nodes = self.graph.nodes.len();
398        let num_edges = self.graph.edges.len();
399
400        let op_type_counts: HashMap<String, usize> =
401            self.graph.nodes.iter().fold(HashMap::new(), |mut acc, node| {
402                *acc.entry(node.op_type.clone()).or_insert(0) += 1;
403                acc
404            });
405
406        let max_depth = self.graph.nodes.iter().map(|n| n.depth).max().unwrap_or(0);
407
408        GraphStatistics {
409            num_nodes,
410            num_edges,
411            num_inputs: self.graph.inputs.len(),
412            num_outputs: self.graph.outputs.len(),
413            max_depth,
414            op_type_counts,
415        }
416    }
417
418    /// Print a summary of the graph
419    pub fn summary(&self) -> String {
420        let stats = self.statistics();
421
422        let mut output = String::new();
423        output.push_str(&format!("Computation Graph: {}\n", self.graph.name));
424        output.push_str(&"=".repeat(60));
425        output.push('\n');
426        output.push_str(&format!("Nodes: {}\n", stats.num_nodes));
427        output.push_str(&format!("Edges: {}\n", stats.num_edges));
428        output.push_str(&format!("Inputs: {}\n", stats.num_inputs));
429        output.push_str(&format!("Outputs: {}\n", stats.num_outputs));
430        output.push_str(&format!("Max Depth: {}\n", stats.max_depth));
431
432        output.push_str("\nOperation Types:\n");
433        for (op_type, count) in &stats.op_type_counts {
434            output.push_str(&format!("  {}: {}\n", op_type, count));
435        }
436
437        output
438    }
439}
440
441/// Graph statistics
442#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct GraphStatistics {
444    /// Number of nodes
445    pub num_nodes: usize,
446    /// Number of edges
447    pub num_edges: usize,
448    /// Number of input nodes
449    pub num_inputs: usize,
450    /// Number of output nodes
451    pub num_outputs: usize,
452    /// Maximum depth
453    pub max_depth: usize,
454    /// Operation type counts
455    pub op_type_counts: HashMap<String, usize>,
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Add a node that has no shape, dtype or attributes.
    fn add_bare(viz: &mut GraphVisualizer, id: &str, label: &str, op: &str) {
        viz.add_node(id, label, op, None, None, HashMap::new());
    }

    #[test]
    fn test_graph_visualizer_creation() {
        let viz = GraphVisualizer::new("test_graph");
        assert_eq!(viz.graph.name, "test_graph");
        assert!(viz.graph.nodes.is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut viz = GraphVisualizer::new("test");
        let dtype = Some(String::from("float32"));
        viz.add_node("node1", "Layer 1", "Linear", Some(vec![10, 20]), dtype, HashMap::new());

        assert_eq!(viz.graph.nodes.len(), 1);
        assert_eq!(viz.graph.nodes[0].id, "node1");
    }

    #[test]
    fn test_add_edge() {
        let mut viz = GraphVisualizer::new("test");
        add_bare(&mut viz, "node1", "N1", "Linear");
        add_bare(&mut viz, "node2", "N2", "ReLU");
        viz.add_edge("node1", "node2", None, Some(vec![10, 20]));

        assert_eq!(viz.graph.edges.len(), 1);
        let edge = &viz.graph.edges[0];
        assert_eq!(edge.from, "node1");
        assert_eq!(edge.to, "node2");
    }

    #[test]
    fn test_mark_input_output() {
        let mut viz = GraphVisualizer::new("test");
        add_bare(&mut viz, "input", "Input", "Input");
        add_bare(&mut viz, "output", "Output", "Output");

        viz.mark_input("input");
        viz.mark_output("output");

        assert_eq!(viz.graph.inputs.len(), 1);
        assert_eq!(viz.graph.outputs.len(), 1);
    }

    #[test]
    fn test_export_to_dot() {
        let path = env::temp_dir().join("test_graph.dot");

        let mut viz = GraphVisualizer::new("test");
        add_bare(&mut viz, "input", "Input", "Input");
        viz.add_node("layer1", "Linear", "Linear", Some(vec![10, 20]), None, HashMap::new());
        viz.add_edge("input", "layer1", None, None);
        viz.mark_input("input");

        viz.export_to_dot(&path).unwrap();
        assert!(path.exists());

        // Best-effort cleanup; a leftover temp file is harmless.
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_export_to_json() {
        let path = env::temp_dir().join("test_graph.json");

        let mut viz = GraphVisualizer::new("test");
        add_bare(&mut viz, "node1", "N1", "Linear");

        viz.export_to_json(&path).unwrap();
        assert!(path.exists());

        // Best-effort cleanup; a leftover temp file is harmless.
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_statistics() {
        let mut viz = GraphVisualizer::new("test");
        for (id, label, op) in [("n1", "N1", "Linear"), ("n2", "N2", "Linear"), ("n3", "N3", "ReLU")] {
            add_bare(&mut viz, id, label, op);
        }
        viz.add_edge("n1", "n2", None, None);
        viz.add_edge("n2", "n3", None, None);
        viz.mark_input("n1");
        viz.mark_output("n3");

        let stats = viz.statistics();

        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 2);
        assert_eq!(stats.num_inputs, 1);
        assert_eq!(stats.num_outputs, 1);
    }

    #[test]
    fn test_summary() {
        let mut viz = GraphVisualizer::new("test_model");
        add_bare(&mut viz, "input", "Input", "Input");
        add_bare(&mut viz, "layer1", "Linear", "Linear");

        let text = viz.summary();
        assert!(text.contains("test_model"));
        assert!(text.contains("Nodes: 2"));
    }

    #[test]
    fn test_compute_depths() {
        let mut viz = GraphVisualizer::new("test");
        add_bare(&mut viz, "input", "Input", "Input");
        add_bare(&mut viz, "layer1", "L1", "Linear");
        add_bare(&mut viz, "layer2", "L2", "ReLU");
        viz.add_edge("input", "layer1", None, None);
        viz.add_edge("layer1", "layer2", None, None);
        viz.mark_input("input");

        viz.compute_depths();

        let depths: Vec<usize> = viz.graph.nodes.iter().map(|node| node.depth).collect();
        assert_eq!(depths, vec![0, 1, 2]);
    }
}