Skip to main content

trustformers_core/autodiff/
debugger.rs

1//! Computation graph debugger for automatic differentiation
2//!
3//! This module provides debugging and visualization tools for computation graphs,
4//! helping developers understand gradient flow, detect issues, and optimize
5//! automatic differentiation computations.
6
7#![allow(unused_variables)] // Autodiff debugger
8
9use super::graph::{ComputationGraph, GraphNode, NodeId, OperationType};
10use crate::errors::{Result, TrustformersError};
11use crate::tensor::Tensor;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Write;
15
16/// Computation graph debugger
17pub struct GraphDebugger {
18    /// Configuration for debugging
19    config: DebuggerConfig,
20    /// Analysis results cache
21    analysis_cache: HashMap<String, AnalysisResult>,
22    /// Breakpoints for debugging
23    breakpoints: HashSet<NodeId>,
24}
25
26/// Configuration for the graph debugger
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct DebuggerConfig {
29    /// Maximum number of nodes to display in summaries
30    pub max_display_nodes: usize,
31    /// Whether to show gradient information
32    pub show_gradients: bool,
33    /// Whether to show tensor shapes
34    pub show_shapes: bool,
35    /// Whether to show tensor values (can be verbose)
36    pub show_values: bool,
37    /// Output format for graph visualization
38    pub output_format: GraphOutputFormat,
39    /// Threshold for gradient magnitude warnings
40    pub gradient_magnitude_threshold: f32,
41    /// Threshold for detecting vanishing gradients
42    pub vanishing_gradient_threshold: f32,
43    /// Threshold for detecting exploding gradients
44    pub exploding_gradient_threshold: f32,
45}
46
47/// Output formats for graph visualization
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
49pub enum GraphOutputFormat {
50    /// DOT format for Graphviz
51    Dot,
52    /// ASCII art representation
53    ASCII,
54    /// JSON format for programmatic use
55    JSON,
56    /// HTML with interactive features
57    HTML,
58}
59
60/// Analysis result for computation graph
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct AnalysisResult {
63    /// Total number of nodes
64    pub total_nodes: usize,
65    /// Number of leaf nodes
66    pub leaf_nodes: usize,
67    /// Number of root nodes
68    pub root_nodes: usize,
69    /// Maximum depth of the graph
70    pub max_depth: usize,
71    /// Number of operations by type
72    pub operation_counts: HashMap<String, usize>,
73    /// Gradient flow statistics
74    pub gradient_stats: GradientFlowStats,
75    /// Memory usage estimates
76    pub memory_stats: MemoryStats,
77    /// Potential issues detected
78    pub issues: Vec<GraphIssue>,
79}
80
81/// Statistics about gradient flow
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct GradientFlowStats {
84    /// Nodes with gradients
85    pub nodes_with_gradients: usize,
86    /// Nodes requiring gradients
87    pub nodes_requiring_gradients: usize,
88    /// Average gradient magnitude
89    pub average_gradient_magnitude: f32,
90    /// Maximum gradient magnitude
91    pub max_gradient_magnitude: f32,
92    /// Minimum gradient magnitude
93    pub min_gradient_magnitude: f32,
94    /// Nodes with vanishing gradients
95    pub vanishing_gradient_nodes: Vec<NodeId>,
96    /// Nodes with exploding gradients
97    pub exploding_gradient_nodes: Vec<NodeId>,
98}
99
100/// Memory usage statistics
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct MemoryStats {
103    /// Total memory used by tensors (bytes)
104    pub total_tensor_memory: usize,
105    /// Total memory used by gradients (bytes)
106    pub total_gradient_memory: usize,
107    /// Peak memory usage estimate
108    pub peak_memory_estimate: usize,
109    /// Memory usage by node
110    pub memory_per_node: HashMap<NodeId, usize>,
111}
112
113/// Issues detected in the computation graph
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GraphIssue {
116    /// Type of issue
117    pub issue_type: IssueType,
118    /// Nodes involved in the issue
119    pub nodes: Vec<NodeId>,
120    /// Description of the issue
121    pub description: String,
122    /// Severity level
123    pub severity: IssueSeverity,
124    /// Suggested fix
125    pub suggestion: Option<String>,
126}
127
128/// Types of issues that can be detected
129#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
130pub enum IssueType {
131    /// Vanishing gradients
132    VanishingGradients,
133    /// Exploding gradients
134    ExplodingGradients,
135    /// Disconnected subgraphs
136    DisconnectedSubgraph,
137    /// Cycles in the graph
138    CyclicDependency,
139    /// Inefficient operations
140    IneffientOperation,
141    /// Shape mismatches
142    ShapeMismatch,
143    /// Memory issues
144    MemoryIssue,
145    /// Numerical instability
146    NumericalInstability,
147}
148
149/// Severity levels for issues
150#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
151pub enum IssueSeverity {
152    /// Critical issues that will cause failures
153    Critical,
154    /// Warning issues that may cause problems
155    Warning,
156    /// Info issues for optimization
157    Info,
158}
159
160/// Node information for debugging
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct NodeDebugInfo {
163    pub id: NodeId,
164    pub name: Option<String>,
165    pub operation: Option<OperationType>,
166    pub shape: Vec<usize>,
167    pub requires_grad: bool,
168    pub is_leaf: bool,
169    pub has_gradient: bool,
170    pub gradient_magnitude: Option<f32>,
171    pub tensor_magnitude: f32,
172    pub memory_usage: usize,
173    pub parents: Vec<NodeId>,
174    pub children: Vec<NodeId>,
175    pub depth_from_root: usize,
176}
177
178/// Graph traversal information
179#[derive(Debug, Clone)]
180pub struct TraversalInfo {
181    pub visited_nodes: HashSet<NodeId>,
182    pub node_depths: HashMap<NodeId, usize>,
183    pub execution_order: Vec<NodeId>,
184}
185
186impl Default for DebuggerConfig {
187    fn default() -> Self {
188        Self {
189            max_display_nodes: 50,
190            show_gradients: true,
191            show_shapes: true,
192            show_values: false,
193            output_format: GraphOutputFormat::Dot,
194            gradient_magnitude_threshold: 1e-6,
195            vanishing_gradient_threshold: 1e-7,
196            exploding_gradient_threshold: 1e3,
197        }
198    }
199}
200
201impl Default for GraphDebugger {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207impl GraphDebugger {
208    /// Create a new graph debugger
209    pub fn new() -> Self {
210        Self {
211            config: DebuggerConfig::default(),
212            analysis_cache: HashMap::new(),
213            breakpoints: HashSet::new(),
214        }
215    }
216
217    /// Create a new graph debugger with custom configuration
218    pub fn with_config(config: DebuggerConfig) -> Self {
219        Self {
220            config,
221            analysis_cache: HashMap::new(),
222            breakpoints: HashSet::new(),
223        }
224    }
225
226    /// Analyze a computation graph
227    pub fn analyze(&mut self, graph: &ComputationGraph) -> Result<AnalysisResult> {
228        let graph_hash = self.compute_graph_hash(graph);
229
230        if let Some(cached_result) = self.analysis_cache.get(&graph_hash) {
231            return Ok(cached_result.clone());
232        }
233
234        let nodes = self.get_all_nodes(graph)?;
235        let total_nodes = nodes.len();
236
237        // Count different types of nodes
238        let leaf_nodes = nodes.iter().filter(|n| n.is_leaf).count();
239        let root_nodes = nodes.iter().filter(|n| n.parents.is_empty()).count();
240
241        // Compute graph depth
242        let max_depth = self.compute_max_depth(graph, &nodes)?;
243
244        // Count operations by type
245        let operation_counts = self.count_operations(&nodes);
246
247        // Analyze gradient flow
248        let gradient_stats = self.analyze_gradient_flow(&nodes)?;
249
250        // Compute memory statistics
251        let memory_stats = self.compute_memory_stats(&nodes)?;
252
253        // Detect issues
254        let issues = self.detect_issues(graph, &nodes, &gradient_stats)?;
255
256        let result = AnalysisResult {
257            total_nodes,
258            leaf_nodes,
259            root_nodes,
260            max_depth,
261            operation_counts,
262            gradient_stats,
263            memory_stats,
264            issues,
265        };
266
267        self.analysis_cache.insert(graph_hash, result.clone());
268        Ok(result)
269    }
270
271    /// Generate a visual representation of the computation graph
272    pub fn visualize(&self, graph: &ComputationGraph) -> Result<String> {
273        match self.config.output_format {
274            GraphOutputFormat::Dot => self.generate_dot_graph(graph),
275            GraphOutputFormat::ASCII => self.generate_ascii_graph(graph),
276            GraphOutputFormat::JSON => self.generate_json_graph(graph),
277            GraphOutputFormat::HTML => self.generate_html_graph(graph),
278        }
279    }
280
281    /// Trace gradient flow from a specific node
282    pub fn trace_gradients(
283        &self,
284        graph: &ComputationGraph,
285        start_node: NodeId,
286    ) -> Result<Vec<NodeDebugInfo>> {
287        let mut trace = Vec::new();
288        let mut visited = HashSet::new();
289        let mut queue = VecDeque::new();
290
291        queue.push_back(start_node);
292
293        while let Some(node_id) = queue.pop_front() {
294            if visited.contains(&node_id) {
295                continue;
296            }
297            visited.insert(node_id);
298
299            let node = self.get_node(graph, node_id)?;
300            let debug_info = self.create_node_debug_info(&node);
301            trace.push(debug_info);
302
303            // Add parent nodes to trace backward through gradients
304            for &parent_id in &node.parents {
305                if !visited.contains(&parent_id) {
306                    queue.push_back(parent_id);
307                }
308            }
309        }
310
311        Ok(trace)
312    }
313
314    /// Set a breakpoint at a specific node
315    pub fn set_breakpoint(&mut self, node_id: NodeId) {
316        self.breakpoints.insert(node_id);
317    }
318
319    /// Remove a breakpoint
320    pub fn remove_breakpoint(&mut self, node_id: NodeId) {
321        self.breakpoints.remove(&node_id);
322    }
323
324    /// Check if execution should break at a node
325    pub fn should_break(&self, node_id: NodeId) -> bool {
326        self.breakpoints.contains(&node_id)
327    }
328
329    /// Get debug information for a specific node
330    pub fn get_node_info(
331        &self,
332        graph: &ComputationGraph,
333        node_id: NodeId,
334    ) -> Result<NodeDebugInfo> {
335        let node = self.get_node(graph, node_id)?;
336        Ok(self.create_node_debug_info(&node))
337    }
338
339    /// Find nodes by name pattern
340    pub fn find_nodes_by_name(
341        &self,
342        graph: &ComputationGraph,
343        pattern: &str,
344    ) -> Result<Vec<NodeId>> {
345        let nodes = self.get_all_nodes(graph)?;
346        let matching_nodes = nodes
347            .iter()
348            .filter(|node| node.name.as_ref().map(|name| name.contains(pattern)).unwrap_or(false))
349            .map(|node| node.id)
350            .collect();
351
352        Ok(matching_nodes)
353    }
354
355    /// Generate a summary report of the computation graph
356    pub fn generate_summary(&mut self, graph: &ComputationGraph) -> Result<String> {
357        let analysis = self.analyze(graph)?;
358        let mut report = String::new();
359
360        writeln!(report, "Computation Graph Summary")?;
361        writeln!(report, "=========================")?;
362        writeln!(report)?;
363
364        writeln!(report, "Graph Structure:")?;
365        writeln!(report, "  Total nodes: {}", analysis.total_nodes)?;
366        writeln!(report, "  Leaf nodes: {}", analysis.leaf_nodes)?;
367        writeln!(report, "  Root nodes: {}", analysis.root_nodes)?;
368        writeln!(report, "  Maximum depth: {}", analysis.max_depth)?;
369        writeln!(report)?;
370
371        writeln!(report, "Operations:")?;
372        for (op_type, count) in &analysis.operation_counts {
373            writeln!(report, "  {}: {}", op_type, count)?;
374        }
375        writeln!(report)?;
376
377        writeln!(report, "Gradient Flow:")?;
378        writeln!(
379            report,
380            "  Nodes with gradients: {}",
381            analysis.gradient_stats.nodes_with_gradients
382        )?;
383        writeln!(
384            report,
385            "  Nodes requiring gradients: {}",
386            analysis.gradient_stats.nodes_requiring_gradients
387        )?;
388        writeln!(
389            report,
390            "  Average gradient magnitude: {:.6}",
391            analysis.gradient_stats.average_gradient_magnitude
392        )?;
393        writeln!(
394            report,
395            "  Max gradient magnitude: {:.6}",
396            analysis.gradient_stats.max_gradient_magnitude
397        )?;
398        writeln!(
399            report,
400            "  Min gradient magnitude: {:.6}",
401            analysis.gradient_stats.min_gradient_magnitude
402        )?;
403        writeln!(report)?;
404
405        writeln!(report, "Memory Usage:")?;
406        writeln!(
407            report,
408            "  Total tensor memory: {} bytes",
409            analysis.memory_stats.total_tensor_memory
410        )?;
411        writeln!(
412            report,
413            "  Total gradient memory: {} bytes",
414            analysis.memory_stats.total_gradient_memory
415        )?;
416        writeln!(
417            report,
418            "  Peak memory estimate: {} bytes",
419            analysis.memory_stats.peak_memory_estimate
420        )?;
421        writeln!(report)?;
422
423        if !analysis.issues.is_empty() {
424            writeln!(report, "Issues Detected:")?;
425            for issue in &analysis.issues {
426                writeln!(
427                    report,
428                    "  [{:?}] {:?}: {}",
429                    issue.severity, issue.issue_type, issue.description
430                )?;
431                if let Some(suggestion) = &issue.suggestion {
432                    writeln!(report, "    Suggestion: {}", suggestion)?;
433                }
434            }
435        } else {
436            writeln!(report, "No issues detected.")?;
437        }
438
439        Ok(report)
440    }
441
442    /// Save debug information to file
443    pub fn save_debug_info(&mut self, graph: &ComputationGraph, path: &str) -> Result<()> {
444        let analysis = self.analyze(graph)?;
445        let json_data = serde_json::to_string_pretty(&analysis)?;
446        std::fs::write(path, json_data)?;
447        Ok(())
448    }
449
450    // Helper methods
451
452    fn get_all_nodes(&self, graph: &ComputationGraph) -> Result<Vec<GraphNode>> {
453        // Access all nodes from the computation graph using the public export method
454        Ok(graph.export_graph().nodes)
455    }
456
457    fn get_node(&self, graph: &ComputationGraph, node_id: NodeId) -> Result<GraphNode> {
458        // Get a specific node from the computation graph
459        let export = graph.export_graph();
460        export.nodes.into_iter().find(|node| node.id == node_id).ok_or_else(|| {
461            TrustformersError::new(crate::errors::ErrorKind::TensorOpError {
462                operation: "get_node".to_string(),
463                reason: format!("Node {} not found in computation graph", node_id),
464            })
465        })
466    }
467
468    fn compute_graph_hash(&self, graph: &ComputationGraph) -> String {
469        // Compute a hash based on graph structure and operations
470        use std::collections::hash_map::DefaultHasher;
471        use std::hash::{Hash, Hasher};
472
473        let mut hasher = DefaultHasher::new();
474
475        // Hash number of nodes
476        graph.num_nodes().hash(&mut hasher);
477
478        // Hash topological order
479        graph.get_topological_order().hash(&mut hasher);
480
481        // Hash each node's structure (operations and connections)
482        let export = graph.export_graph();
483        let mut nodes = export.nodes;
484        nodes.sort_by_key(|node| node.id);
485
486        for node in nodes {
487            node.id.hash(&mut hasher);
488
489            // Hash operation type
490            if let Some(ref op) = node.operation {
491                std::mem::discriminant(op).hash(&mut hasher);
492            }
493
494            // Hash parent connections
495            let mut parents = node.parents.clone();
496            parents.sort();
497            parents.hash(&mut hasher);
498
499            // Hash whether requires grad
500            node.requires_grad.hash(&mut hasher);
501            node.is_leaf.hash(&mut hasher);
502        }
503
504        format!("graph_{:x}", hasher.finish())
505    }
506
507    fn compute_max_depth(&self, graph: &ComputationGraph, nodes: &[GraphNode]) -> Result<usize> {
508        let mut max_depth = 0;
509        let mut visited = HashSet::new();
510
511        for node in nodes {
512            if node.is_leaf {
513                let depth = self.compute_node_depth(graph, node.id, &mut visited)?;
514                max_depth = max_depth.max(depth);
515            }
516        }
517
518        Ok(max_depth)
519    }
520
521    fn compute_node_depth(
522        &self,
523        graph: &ComputationGraph,
524        node_id: NodeId,
525        visited: &mut HashSet<NodeId>,
526    ) -> Result<usize> {
527        if visited.contains(&node_id) {
528            return Ok(0); // Avoid infinite recursion
529        }
530        visited.insert(node_id);
531
532        let node = self.get_node(graph, node_id)?;
533        if node.children.is_empty() {
534            return Ok(0);
535        }
536
537        let mut max_child_depth = 0;
538        for &child_id in &node.children {
539            let child_depth = self.compute_node_depth(graph, child_id, visited)?;
540            max_child_depth = max_child_depth.max(child_depth);
541        }
542
543        Ok(max_child_depth + 1)
544    }
545
546    fn count_operations(&self, nodes: &[GraphNode]) -> HashMap<String, usize> {
547        let mut counts = HashMap::new();
548
549        for node in nodes {
550            if let Some(ref op) = node.operation {
551                let op_name = format!("{:?}", op);
552                *counts.entry(op_name).or_insert(0) += 1;
553            }
554        }
555
556        counts
557    }
558
559    fn analyze_gradient_flow(&self, nodes: &[GraphNode]) -> Result<GradientFlowStats> {
560        let nodes_with_gradients = nodes.iter().filter(|n| n.gradient.is_some()).count();
561        let nodes_requiring_gradients = nodes.iter().filter(|n| n.requires_grad).count();
562
563        let gradient_magnitudes: Vec<f32> = nodes
564            .iter()
565            .filter_map(|node| {
566                node.gradient.as_ref().and_then(|grad| self.compute_tensor_magnitude(grad).ok())
567            })
568            .collect();
569
570        let (average_gradient_magnitude, max_gradient_magnitude, min_gradient_magnitude) =
571            if gradient_magnitudes.is_empty() {
572                (0.0, 0.0, 0.0)
573            } else {
574                let sum: f32 = gradient_magnitudes.iter().sum();
575                let avg = sum / gradient_magnitudes.len() as f32;
576                let max = gradient_magnitudes.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
577                let min = gradient_magnitudes.iter().fold(f32::INFINITY, |a, &b| a.min(b));
578                (avg, max, min)
579            };
580
581        let vanishing_gradient_nodes: Vec<NodeId> = nodes
582            .iter()
583            .filter(|node| {
584                node.gradient
585                    .as_ref()
586                    .and_then(|grad| self.compute_tensor_magnitude(grad).ok())
587                    .map(|mag| mag < self.config.vanishing_gradient_threshold)
588                    .unwrap_or(false)
589            })
590            .map(|node| node.id)
591            .collect();
592
593        let exploding_gradient_nodes: Vec<NodeId> = nodes
594            .iter()
595            .filter(|node| {
596                node.gradient
597                    .as_ref()
598                    .and_then(|grad| self.compute_tensor_magnitude(grad).ok())
599                    .map(|mag| mag > self.config.exploding_gradient_threshold)
600                    .unwrap_or(false)
601            })
602            .map(|node| node.id)
603            .collect();
604
605        Ok(GradientFlowStats {
606            nodes_with_gradients,
607            nodes_requiring_gradients,
608            average_gradient_magnitude,
609            max_gradient_magnitude,
610            min_gradient_magnitude,
611            vanishing_gradient_nodes,
612            exploding_gradient_nodes,
613        })
614    }
615
616    fn compute_memory_stats(&self, nodes: &[GraphNode]) -> Result<MemoryStats> {
617        let mut total_tensor_memory = 0;
618        let mut total_gradient_memory = 0;
619        let mut memory_per_node = HashMap::new();
620
621        for node in nodes {
622            let tensor_memory = node.value.memory_usage();
623            let gradient_memory = node.gradient.as_ref().map(|g| g.memory_usage()).unwrap_or(0);
624
625            total_tensor_memory += tensor_memory;
626            total_gradient_memory += gradient_memory;
627            memory_per_node.insert(node.id, tensor_memory + gradient_memory);
628        }
629
630        let peak_memory_estimate = total_tensor_memory + total_gradient_memory;
631
632        Ok(MemoryStats {
633            total_tensor_memory,
634            total_gradient_memory,
635            peak_memory_estimate,
636            memory_per_node,
637        })
638    }
639
640    fn detect_issues(
641        &self,
642        graph: &ComputationGraph,
643        nodes: &[GraphNode],
644        gradient_stats: &GradientFlowStats,
645    ) -> Result<Vec<GraphIssue>> {
646        let mut issues = Vec::new();
647
648        // Check for vanishing gradients
649        if !gradient_stats.vanishing_gradient_nodes.is_empty() {
650            issues.push(GraphIssue {
651                issue_type: IssueType::VanishingGradients,
652                nodes: gradient_stats.vanishing_gradient_nodes.clone(),
653                description: format!(
654                    "Detected {} nodes with vanishing gradients",
655                    gradient_stats.vanishing_gradient_nodes.len()
656                ),
657                severity: IssueSeverity::Warning,
658                suggestion: Some(
659                    "Consider using gradient clipping or adjusting learning rates".to_string(),
660                ),
661            });
662        }
663
664        // Check for exploding gradients
665        if !gradient_stats.exploding_gradient_nodes.is_empty() {
666            issues.push(GraphIssue {
667                issue_type: IssueType::ExplodingGradients,
668                nodes: gradient_stats.exploding_gradient_nodes.clone(),
669                description: format!(
670                    "Detected {} nodes with exploding gradients",
671                    gradient_stats.exploding_gradient_nodes.len()
672                ),
673                severity: IssueSeverity::Critical,
674                suggestion: Some("Apply gradient clipping to prevent instability".to_string()),
675            });
676        }
677
678        // Check for disconnected subgraphs
679        let disconnected_nodes = self.find_disconnected_nodes(graph, nodes)?;
680        if !disconnected_nodes.is_empty() {
681            issues.push(GraphIssue {
682                issue_type: IssueType::DisconnectedSubgraph,
683                nodes: disconnected_nodes,
684                description: "Found disconnected nodes in the computation graph".to_string(),
685                severity: IssueSeverity::Warning,
686                suggestion: Some("Check that all variables are properly connected".to_string()),
687            });
688        }
689
690        Ok(issues)
691    }
692
693    fn find_disconnected_nodes(
694        &self,
695        graph: &ComputationGraph,
696        nodes: &[GraphNode],
697    ) -> Result<Vec<NodeId>> {
698        // Placeholder implementation
699        Ok(Vec::new())
700    }
701
702    fn compute_tensor_magnitude(&self, tensor: &Tensor) -> Result<f32> {
703        match tensor {
704            Tensor::F32(arr) => {
705                let magnitude = arr.iter().map(|&x| x * x).sum::<f32>().sqrt();
706                Ok(magnitude)
707            },
708            _ => Err(TrustformersError::new(
709                crate::errors::ErrorKind::TensorOpError {
710                    operation: "compute_magnitude".to_string(),
711                    reason: "Magnitude computation not supported for this tensor type".to_string(),
712                },
713            )),
714        }
715    }
716
717    fn create_node_debug_info(&self, node: &GraphNode) -> NodeDebugInfo {
718        let gradient_magnitude =
719            node.gradient.as_ref().and_then(|grad| self.compute_tensor_magnitude(grad).ok());
720
721        let tensor_magnitude = self.compute_tensor_magnitude(&node.value).unwrap_or(0.0);
722
723        NodeDebugInfo {
724            id: node.id,
725            name: node.name.clone(),
726            operation: node.operation.clone(),
727            shape: node.shape.clone(),
728            requires_grad: node.requires_grad,
729            is_leaf: node.is_leaf,
730            has_gradient: node.gradient.is_some(),
731            gradient_magnitude,
732            tensor_magnitude,
733            memory_usage: node.value.memory_usage(),
734            parents: node.parents.clone(),
735            children: node.children.clone(),
736            depth_from_root: 0, // Would be computed in real implementation
737        }
738    }
739
740    fn generate_dot_graph(&self, graph: &ComputationGraph) -> Result<String> {
741        let mut dot = String::new();
742        writeln!(dot, "digraph ComputationGraph {{")?;
743        writeln!(dot, "  rankdir=TB;")?;
744        writeln!(dot, "  node [shape=box, style=filled, fontname=Arial];")?;
745
746        let nodes = self.get_all_nodes(graph)?;
747
748        for node in &nodes {
749            let color = if node.is_leaf {
750                "lightblue"
751            } else if node.gradient.is_some() {
752                "lightgreen"
753            } else {
754                "lightgray"
755            };
756
757            let label = if let Some(ref name) = node.name {
758                format!(
759                    "{}\\n{:?}",
760                    name,
761                    node.operation.as_ref().unwrap_or(&OperationType::Add)
762                )
763            } else {
764                format!(
765                    "Node {}\\n{:?}",
766                    node.id,
767                    node.operation.as_ref().unwrap_or(&OperationType::Add)
768                )
769            };
770
771            writeln!(
772                dot,
773                "  {} [label=\"{}\", fillcolor={}];",
774                node.id, label, color
775            )?;
776        }
777
778        for node in &nodes {
779            for &child_id in &node.children {
780                writeln!(dot, "  {} -> {};", node.id, child_id)?;
781            }
782        }
783
784        writeln!(dot, "}}")?;
785        Ok(dot)
786    }
787
788    fn generate_ascii_graph(&self, graph: &ComputationGraph) -> Result<String> {
789        let mut output = String::new();
790        writeln!(output, "Computation Graph (ASCII)")?;
791        writeln!(output, "=========================")?;
792
793        let nodes = self.get_all_nodes(graph)?;
794
795        for node in &nodes {
796            let status = if node.is_leaf { "[LEAF]" } else { "[OP]" };
797            let grad_status = if node.gradient.is_some() { "[GRAD]" } else { "" };
798
799            writeln!(
800                output,
801                "Node {}: {} {} {:?}",
802                node.id,
803                status,
804                grad_status,
805                node.operation.as_ref().unwrap_or(&OperationType::Add)
806            )?;
807
808            if !node.children.is_empty() {
809                writeln!(output, "  └─ Children: {:?}", node.children)?;
810            }
811        }
812
813        Ok(output)
814    }
815
816    fn generate_json_graph(&self, graph: &ComputationGraph) -> Result<String> {
817        let nodes = self.get_all_nodes(graph)?;
818        let debug_nodes: Vec<NodeDebugInfo> =
819            nodes.iter().map(|node| self.create_node_debug_info(node)).collect();
820
821        let json_data = serde_json::json!({
822            "nodes": debug_nodes,
823            "total_nodes": nodes.len(),
824        });
825
826        Ok(serde_json::to_string_pretty(&json_data)?)
827    }
828
829    fn generate_html_graph(&self, graph: &ComputationGraph) -> Result<String> {
830        let mut html = String::new();
831
832        html.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
833        html.push_str("<title>Computation Graph Debug View</title>\n");
834        html.push_str("<style>\n");
835        html.push_str("body { font-family: Arial, sans-serif; margin: 20px; }\n");
836        html.push_str(
837            ".node { border: 1px solid #ccc; margin: 10px; padding: 10px; border-radius: 5px; }\n",
838        );
839        html.push_str(".leaf { background-color: #e3f2fd; }\n");
840        html.push_str(".op { background-color: #f3e5f5; }\n");
841        html.push_str(".grad { border-left: 4px solid #4caf50; }\n");
842        html.push_str("</style>\n");
843        html.push_str("</head>\n<body>\n");
844
845        html.push_str("<h1>Computation Graph Debug View</h1>\n");
846
847        let nodes = self.get_all_nodes(graph)?;
848
849        for node in &nodes {
850            let node_class = if node.is_leaf { "node leaf" } else { "node op" };
851            let grad_class = if node.gradient.is_some() { " grad" } else { "" };
852
853            html.push_str(&format!("<div class=\"{}{}\">\n", node_class, grad_class));
854            html.push_str(&format!("<h3>Node {}</h3>\n", node.id));
855
856            if let Some(ref name) = node.name {
857                html.push_str(&format!("<p><strong>Name:</strong> {}</p>\n", name));
858            }
859
860            if let Some(ref op) = node.operation {
861                html.push_str(&format!("<p><strong>Operation:</strong> {:?}</p>\n", op));
862            }
863
864            html.push_str(&format!(
865                "<p><strong>Shape:</strong> {:?}</p>\n",
866                node.shape
867            ));
868            html.push_str(&format!(
869                "<p><strong>Requires Grad:</strong> {}</p>\n",
870                node.requires_grad
871            ));
872            html.push_str(&format!(
873                "<p><strong>Is Leaf:</strong> {}</p>\n",
874                node.is_leaf
875            ));
876            html.push_str(&format!(
877                "<p><strong>Has Gradient:</strong> {}</p>\n",
878                node.gradient.is_some()
879            ));
880            html.push_str(&format!(
881                "<p><strong>Memory:</strong> {} bytes</p>\n",
882                node.value.memory_usage()
883            ));
884
885            html.push_str("</div>\n");
886        }
887
888        html.push_str("</body>\n</html>\n");
889        Ok(html)
890    }
891}
892
893// From<std::fmt::Error> implementation is already provided in error.rs
894
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_debugger_creation() {
        let debugger = GraphDebugger::new();
        assert_eq!(debugger.config.max_display_nodes, 50);
        assert_eq!(debugger.config.output_format, GraphOutputFormat::Dot);
        assert!(debugger.analysis_cache.is_empty());
        assert!(debugger.breakpoints.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = GraphDebugger::default();
        let from_new = GraphDebugger::new();
        assert_eq!(
            from_default.config.max_display_nodes,
            from_new.config.max_display_nodes
        );
        assert_eq!(from_default.config.output_format, from_new.config.output_format);
        assert_eq!(
            from_default.config.exploding_gradient_threshold,
            from_new.config.exploding_gradient_threshold
        );
    }

    #[test]
    fn test_with_config_preserves_settings() {
        let config = DebuggerConfig {
            max_display_nodes: 7,
            output_format: GraphOutputFormat::ASCII,
            exploding_gradient_threshold: 10.0,
            ..DebuggerConfig::default()
        };
        let debugger = GraphDebugger::with_config(config);
        assert_eq!(debugger.config.max_display_nodes, 7);
        assert_eq!(debugger.config.output_format, GraphOutputFormat::ASCII);
        assert_eq!(debugger.config.exploding_gradient_threshold, 10.0);
        assert!(debugger.breakpoints.is_empty());
    }

    #[test]
    fn test_default_thresholds_are_ordered() {
        let config = DebuggerConfig::default();
        assert!(config.vanishing_gradient_threshold < config.gradient_magnitude_threshold);
        assert!(config.gradient_magnitude_threshold < config.exploding_gradient_threshold);
        assert!(config.show_gradients);
        assert!(config.show_shapes);
        assert!(!config.show_values);
    }

    #[test]
    fn test_config_serialization() {
        let config = DebuggerConfig {
            output_format: GraphOutputFormat::HTML,
            show_values: true,
            ..DebuggerConfig::default()
        };
        let serialized = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: DebuggerConfig =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");

        assert_eq!(config.max_display_nodes, deserialized.max_display_nodes);
        assert_eq!(config.show_gradients, deserialized.show_gradients);
        assert_eq!(config.show_values, deserialized.show_values);
        assert_eq!(config.output_format, deserialized.output_format);
        assert_eq!(
            config.exploding_gradient_threshold,
            deserialized.exploding_gradient_threshold
        );
    }

    #[test]
    fn test_breakpoint_management() {
        let mut debugger = GraphDebugger::new();

        debugger.set_breakpoint(1);
        debugger.set_breakpoint(2);

        assert!(debugger.should_break(1));
        assert!(debugger.should_break(2));
        assert!(!debugger.should_break(3));

        debugger.remove_breakpoint(1);
        assert!(!debugger.should_break(1));
        assert!(debugger.should_break(2));

        // Setting twice is idempotent; removing an absent breakpoint is a no-op.
        debugger.set_breakpoint(2);
        debugger.remove_breakpoint(2);
        assert!(!debugger.should_break(2));
        debugger.remove_breakpoint(42);
        assert!(!debugger.should_break(42));
    }

    #[test]
    fn test_issue_severity() {
        assert_ne!(IssueSeverity::Critical, IssueSeverity::Warning);
        assert_ne!(IssueSeverity::Warning, IssueSeverity::Info);
        assert_ne!(IssueSeverity::Critical, IssueSeverity::Info);

        for severity in [IssueSeverity::Critical, IssueSeverity::Warning, IssueSeverity::Info] {
            let json = serde_json::to_string(&severity).expect("JSON serialization failed");
            let back: IssueSeverity =
                serde_json::from_str(&json).expect("JSON deserialization failed");
            assert_eq!(back, severity);
        }
    }

    #[test]
    fn test_issue_types() {
        let issue = GraphIssue {
            issue_type: IssueType::VanishingGradients,
            nodes: vec![1, 2, 3],
            description: "Test issue".to_string(),
            severity: IssueSeverity::Warning,
            suggestion: Some("Test suggestion".to_string()),
        };

        assert_eq!(issue.issue_type, IssueType::VanishingGradients);
        assert_eq!(issue.nodes.len(), 3);
        assert!(issue.suggestion.is_some());

        let json = serde_json::to_string(&issue).expect("JSON serialization failed");
        let back: GraphIssue = serde_json::from_str(&json).expect("JSON deserialization failed");
        assert_eq!(back.issue_type, issue.issue_type);
        assert_eq!(back.nodes, issue.nodes);
        assert_eq!(back.description, issue.description);
        assert_eq!(back.severity, issue.severity);
        assert_eq!(back.suggestion, issue.suggestion);
    }

    #[test]
    fn test_output_formats() {
        let formats = [
            GraphOutputFormat::Dot,
            GraphOutputFormat::ASCII,
            GraphOutputFormat::JSON,
            GraphOutputFormat::HTML,
        ];
        for (i, a) in formats.iter().enumerate() {
            for (j, b) in formats.iter().enumerate() {
                assert_eq!(i == j, a == b);
            }
            let json = serde_json::to_string(a).expect("JSON serialization failed");
            let back: GraphOutputFormat =
                serde_json::from_str(&json).expect("JSON deserialization failed");
            assert_eq!(back, *a);
        }
    }
}