1#![allow(unused_variables)] use super::graph::{ComputationGraph, GraphNode, NodeId, OperationType};
10use crate::errors::{Result, TrustformersError};
11use crate::tensor::Tensor;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Write;
15
/// Interactive debugger for inspecting, analyzing, and visualizing
/// computation graphs.
pub struct GraphDebugger {
    /// Thresholds and output options controlling analysis and rendering.
    config: DebuggerConfig,
    /// Cache of analysis results keyed by a structural hash of the graph,
    /// so re-analyzing an unchanged graph is cheap.
    analysis_cache: HashMap<String, AnalysisResult>,
    /// Node ids for which `should_break` reports true.
    breakpoints: HashSet<NodeId>,
}
25
/// Tunable settings for [`GraphDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebuggerConfig {
    /// Upper bound on nodes shown in visualizations.
    /// NOTE(review): not enforced by the generators in this file — confirm
    /// whether truncation is intended elsewhere.
    pub max_display_nodes: usize,
    /// Whether gradient information should be included in output.
    pub show_gradients: bool,
    /// Whether tensor shapes should be included in output.
    pub show_shapes: bool,
    /// Whether tensor values should be included in output.
    pub show_values: bool,
    /// Format used by `GraphDebugger::visualize`.
    pub output_format: GraphOutputFormat,
    /// General magnitude threshold for gradient reporting.
    pub gradient_magnitude_threshold: f32,
    /// Gradient L2 magnitudes below this are flagged as vanishing.
    pub vanishing_gradient_threshold: f32,
    /// Gradient L2 magnitudes above this are flagged as exploding.
    pub exploding_gradient_threshold: f32,
}
46
/// Output formats supported by `GraphDebugger::visualize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GraphOutputFormat {
    /// Graphviz DOT source.
    Dot,
    /// Plain-text, line-per-node listing.
    ASCII,
    /// Pretty-printed JSON of `NodeDebugInfo` records.
    JSON,
    /// Standalone HTML page with one card per node.
    HTML,
}
59
/// Aggregate output of `GraphDebugger::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisResult {
    /// Total number of nodes in the graph.
    pub total_nodes: usize,
    /// Number of nodes with `is_leaf == true`.
    pub leaf_nodes: usize,
    /// Number of nodes with no parents (graph inputs).
    pub root_nodes: usize,
    /// Longest child-chain length found from any leaf.
    pub max_depth: usize,
    /// Count of nodes per operation, keyed by the operation's `Debug` string.
    pub operation_counts: HashMap<String, usize>,
    /// Gradient magnitude statistics and threshold violations.
    pub gradient_stats: GradientFlowStats,
    /// Tensor/gradient memory accounting.
    pub memory_stats: MemoryStats,
    /// Problems detected by `detect_issues`.
    pub issues: Vec<GraphIssue>,
}
80
/// Summary statistics over gradient magnitudes across the graph.
///
/// All magnitudes are L2 norms (see `compute_tensor_magnitude`); nodes whose
/// gradient tensor type is unsupported are excluded from the statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowStats {
    /// Nodes that currently hold a gradient tensor.
    pub nodes_with_gradients: usize,
    /// Nodes with `requires_grad == true`.
    pub nodes_requiring_gradients: usize,
    /// Mean gradient magnitude (0.0 when no magnitudes are available).
    pub average_gradient_magnitude: f32,
    /// Largest gradient magnitude (0.0 when none available).
    pub max_gradient_magnitude: f32,
    /// Smallest gradient magnitude (0.0 when none available).
    pub min_gradient_magnitude: f32,
    /// Nodes whose gradient magnitude is below the vanishing threshold.
    pub vanishing_gradient_nodes: Vec<NodeId>,
    /// Nodes whose gradient magnitude is above the exploding threshold.
    pub exploding_gradient_nodes: Vec<NodeId>,
}
99
/// Memory accounting for a computation graph, in bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Sum of `memory_usage()` over all value tensors.
    pub total_tensor_memory: usize,
    /// Sum of `memory_usage()` over all gradient tensors.
    pub total_gradient_memory: usize,
    /// Estimated peak: currently tensors + gradients, i.e. assumes
    /// everything is resident simultaneously.
    pub peak_memory_estimate: usize,
    /// Combined (tensor + gradient) bytes per node.
    pub memory_per_node: HashMap<NodeId, usize>,
}
112
/// A single problem detected during graph analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphIssue {
    /// Category of the problem.
    pub issue_type: IssueType,
    /// The nodes implicated in the issue.
    pub nodes: Vec<NodeId>,
    /// Human-readable explanation.
    pub description: String,
    /// How urgent the issue is.
    pub severity: IssueSeverity,
    /// Optional remediation hint for the user.
    pub suggestion: Option<String>,
}
127
/// Categories of detectable graph problems.
///
/// Only `VanishingGradients`, `ExplodingGradients`, and
/// `DisconnectedSubgraph` are currently produced by `detect_issues`; the
/// remaining variants are reserved for future checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IssueType {
    VanishingGradients,
    ExplodingGradients,
    DisconnectedSubgraph,
    CyclicDependency,
    // NOTE(review): misspelled — should be `InefficientOperation`. Renaming
    // is a breaking change for downstream matches and serialized data, so it
    // must be done as a coordinated rename, not silently here.
    IneffientOperation,
    ShapeMismatch,
    MemoryIssue,
    NumericalInstability,
}
148
/// Urgency level attached to a `GraphIssue`.
///
/// Note: no `Ord` is derived, so variants have no intrinsic ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IssueSeverity {
    /// Likely to break training or produce invalid results.
    Critical,
    /// Suspicious but not necessarily fatal.
    Warning,
    /// Informational only.
    Info,
}
159
/// Flattened, serializable snapshot of a single graph node, produced by
/// `create_node_debug_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDebugInfo {
    pub id: NodeId,
    /// User-assigned name, if any.
    pub name: Option<String>,
    /// Producing operation; `None` for input/leaf-style nodes.
    pub operation: Option<OperationType>,
    pub shape: Vec<usize>,
    pub requires_grad: bool,
    pub is_leaf: bool,
    pub has_gradient: bool,
    /// L2 norm of the gradient, when present and computable.
    pub gradient_magnitude: Option<f32>,
    /// L2 norm of the value tensor (0.0 when not computable).
    pub tensor_magnitude: f32,
    /// Bytes used by the value tensor (gradient memory not included here).
    pub memory_usage: usize,
    pub parents: Vec<NodeId>,
    pub children: Vec<NodeId>,
    /// Always 0 in the current implementation — see `create_node_debug_info`.
    pub depth_from_root: usize,
}
177
/// Bookkeeping for a graph traversal.
///
/// NOTE(review): not referenced anywhere in this module — verify it is used
/// by other modules before removing.
#[derive(Debug, Clone)]
pub struct TraversalInfo {
    pub visited_nodes: HashSet<NodeId>,
    pub node_depths: HashMap<NodeId, usize>,
    pub execution_order: Vec<NodeId>,
}
185
impl Default for DebuggerConfig {
    /// Sensible defaults: DOT output, shapes and gradients shown (values
    /// hidden), vanishing below 1e-7, exploding above 1e3.
    fn default() -> Self {
        Self {
            max_display_nodes: 50,
            show_gradients: true,
            show_shapes: true,
            show_values: false,
            output_format: GraphOutputFormat::Dot,
            gradient_magnitude_threshold: 1e-6,
            vanishing_gradient_threshold: 1e-7,
            exploding_gradient_threshold: 1e3,
        }
    }
}
200
impl Default for GraphDebugger {
    /// Equivalent to `GraphDebugger::new()`.
    fn default() -> Self {
        Self::new()
    }
}
206
207impl GraphDebugger {
208 pub fn new() -> Self {
210 Self {
211 config: DebuggerConfig::default(),
212 analysis_cache: HashMap::new(),
213 breakpoints: HashSet::new(),
214 }
215 }
216
217 pub fn with_config(config: DebuggerConfig) -> Self {
219 Self {
220 config,
221 analysis_cache: HashMap::new(),
222 breakpoints: HashSet::new(),
223 }
224 }
225
    /// Runs a full structural/gradient/memory analysis of `graph`.
    ///
    /// Results are cached under a structural hash of the graph, so repeated
    /// calls on an unchanged graph return the cached result without
    /// re-traversing.
    ///
    /// # Errors
    /// Propagates failures from node export, depth computation, gradient
    /// analysis, memory accounting, or issue detection.
    pub fn analyze(&mut self, graph: &ComputationGraph) -> Result<AnalysisResult> {
        let graph_hash = self.compute_graph_hash(graph);

        // Fast path: this graph structure has been analyzed before.
        if let Some(cached_result) = self.analysis_cache.get(&graph_hash) {
            return Ok(cached_result.clone());
        }

        let nodes = self.get_all_nodes(graph)?;
        let total_nodes = nodes.len();

        let leaf_nodes = nodes.iter().filter(|n| n.is_leaf).count();
        // Roots are nodes with no parents (graph inputs).
        let root_nodes = nodes.iter().filter(|n| n.parents.is_empty()).count();

        let max_depth = self.compute_max_depth(graph, &nodes)?;

        let operation_counts = self.count_operations(&nodes);

        let gradient_stats = self.analyze_gradient_flow(&nodes)?;

        let memory_stats = self.compute_memory_stats(&nodes)?;

        let issues = self.detect_issues(graph, &nodes, &gradient_stats)?;

        let result = AnalysisResult {
            total_nodes,
            leaf_nodes,
            root_nodes,
            max_depth,
            operation_counts,
            gradient_stats,
            memory_stats,
            issues,
        };

        // Cache for subsequent calls on the same structure.
        self.analysis_cache.insert(graph_hash, result.clone());
        Ok(result)
    }
270
    /// Renders the graph using the format selected in
    /// `self.config.output_format`.
    pub fn visualize(&self, graph: &ComputationGraph) -> Result<String> {
        match self.config.output_format {
            GraphOutputFormat::Dot => self.generate_dot_graph(graph),
            GraphOutputFormat::ASCII => self.generate_ascii_graph(graph),
            GraphOutputFormat::JSON => self.generate_json_graph(graph),
            GraphOutputFormat::HTML => self.generate_html_graph(graph),
        }
    }
280
281 pub fn trace_gradients(
283 &self,
284 graph: &ComputationGraph,
285 start_node: NodeId,
286 ) -> Result<Vec<NodeDebugInfo>> {
287 let mut trace = Vec::new();
288 let mut visited = HashSet::new();
289 let mut queue = VecDeque::new();
290
291 queue.push_back(start_node);
292
293 while let Some(node_id) = queue.pop_front() {
294 if visited.contains(&node_id) {
295 continue;
296 }
297 visited.insert(node_id);
298
299 let node = self.get_node(graph, node_id)?;
300 let debug_info = self.create_node_debug_info(&node);
301 trace.push(debug_info);
302
303 for &parent_id in &node.parents {
305 if !visited.contains(&parent_id) {
306 queue.push_back(parent_id);
307 }
308 }
309 }
310
311 Ok(trace)
312 }
313
    /// Registers `node_id` so that `should_break` returns true for it.
    pub fn set_breakpoint(&mut self, node_id: NodeId) {
        self.breakpoints.insert(node_id);
    }
318
    /// Clears a previously set breakpoint; a no-op if none was set.
    pub fn remove_breakpoint(&mut self, node_id: NodeId) {
        self.breakpoints.remove(&node_id);
    }
323
    /// Returns true when a breakpoint is set on `node_id`.
    pub fn should_break(&self, node_id: NodeId) -> bool {
        self.breakpoints.contains(&node_id)
    }
328
    /// Builds a `NodeDebugInfo` snapshot for a single node.
    ///
    /// # Errors
    /// Fails when `node_id` does not exist in the graph.
    pub fn get_node_info(
        &self,
        graph: &ComputationGraph,
        node_id: NodeId,
    ) -> Result<NodeDebugInfo> {
        let node = self.get_node(graph, node_id)?;
        Ok(self.create_node_debug_info(&node))
    }
338
339 pub fn find_nodes_by_name(
341 &self,
342 graph: &ComputationGraph,
343 pattern: &str,
344 ) -> Result<Vec<NodeId>> {
345 let nodes = self.get_all_nodes(graph)?;
346 let matching_nodes = nodes
347 .iter()
348 .filter(|node| node.name.as_ref().map(|name| name.contains(pattern)).unwrap_or(false))
349 .map(|node| node.id)
350 .collect();
351
352 Ok(matching_nodes)
353 }
354
    /// Builds a human-readable, multi-section text report of the analysis:
    /// structure, operation counts, gradient flow, memory usage, and any
    /// detected issues.
    ///
    /// Takes `&mut self` because it calls `analyze`, which may populate the
    /// analysis cache.
    ///
    /// # Errors
    /// Propagates analysis failures and `fmt::Write` errors.
    pub fn generate_summary(&mut self, graph: &ComputationGraph) -> Result<String> {
        let analysis = self.analyze(graph)?;
        let mut report = String::new();

        writeln!(report, "Computation Graph Summary")?;
        writeln!(report, "=========================")?;
        writeln!(report)?;

        // Structure section.
        writeln!(report, "Graph Structure:")?;
        writeln!(report, "  Total nodes: {}", analysis.total_nodes)?;
        writeln!(report, "  Leaf nodes: {}", analysis.leaf_nodes)?;
        writeln!(report, "  Root nodes: {}", analysis.root_nodes)?;
        writeln!(report, "  Maximum depth: {}", analysis.max_depth)?;
        writeln!(report)?;

        // Operation counts (iteration order follows the HashMap and is
        // therefore unspecified).
        writeln!(report, "Operations:")?;
        for (op_type, count) in &analysis.operation_counts {
            writeln!(report, "  {}: {}", op_type, count)?;
        }
        writeln!(report)?;

        // Gradient flow section.
        writeln!(report, "Gradient Flow:")?;
        writeln!(
            report,
            "  Nodes with gradients: {}",
            analysis.gradient_stats.nodes_with_gradients
        )?;
        writeln!(
            report,
            "  Nodes requiring gradients: {}",
            analysis.gradient_stats.nodes_requiring_gradients
        )?;
        writeln!(
            report,
            "  Average gradient magnitude: {:.6}",
            analysis.gradient_stats.average_gradient_magnitude
        )?;
        writeln!(
            report,
            "  Max gradient magnitude: {:.6}",
            analysis.gradient_stats.max_gradient_magnitude
        )?;
        writeln!(
            report,
            "  Min gradient magnitude: {:.6}",
            analysis.gradient_stats.min_gradient_magnitude
        )?;
        writeln!(report)?;

        // Memory section.
        writeln!(report, "Memory Usage:")?;
        writeln!(
            report,
            "  Total tensor memory: {} bytes",
            analysis.memory_stats.total_tensor_memory
        )?;
        writeln!(
            report,
            "  Total gradient memory: {} bytes",
            analysis.memory_stats.total_gradient_memory
        )?;
        writeln!(
            report,
            "  Peak memory estimate: {} bytes",
            analysis.memory_stats.peak_memory_estimate
        )?;
        writeln!(report)?;

        // Issues section, including optional remediation suggestions.
        if !analysis.issues.is_empty() {
            writeln!(report, "Issues Detected:")?;
            for issue in &analysis.issues {
                writeln!(
                    report,
                    "  [{:?}] {:?}: {}",
                    issue.severity, issue.issue_type, issue.description
                )?;
                if let Some(suggestion) = &issue.suggestion {
                    writeln!(report, "    Suggestion: {}", suggestion)?;
                }
            }
        } else {
            writeln!(report, "No issues detected.")?;
        }

        Ok(report)
    }
441
    /// Runs the analysis and writes it to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Propagates analysis, JSON-serialization, and filesystem errors.
    pub fn save_debug_info(&mut self, graph: &ComputationGraph, path: &str) -> Result<()> {
        let analysis = self.analyze(graph)?;
        let json_data = serde_json::to_string_pretty(&analysis)?;
        std::fs::write(path, json_data)?;
        Ok(())
    }
449
    /// Snapshots all nodes by exporting the graph.
    ///
    /// NOTE(review): `export_graph` appears to materialize the full node set
    /// on every call — acceptable for a debugger, but avoid in hot paths.
    fn get_all_nodes(&self, graph: &ComputationGraph) -> Result<Vec<GraphNode>> {
        Ok(graph.export_graph().nodes)
    }
456
    /// Looks up a single node by id via a full export and linear scan (O(n)
    /// per call — callers that need many nodes should use `get_all_nodes`).
    ///
    /// # Errors
    /// Returns a `TensorOpError` when no node with `node_id` exists.
    fn get_node(&self, graph: &ComputationGraph, node_id: NodeId) -> Result<GraphNode> {
        let export = graph.export_graph();
        export.nodes.into_iter().find(|node| node.id == node_id).ok_or_else(|| {
            TrustformersError::new(crate::errors::ErrorKind::TensorOpError {
                operation: "get_node".to_string(),
                reason: format!("Node {} not found in computation graph", node_id),
            })
        })
    }
467
    /// Produces a structural fingerprint of the graph, used as the analysis
    /// cache key.
    ///
    /// The hash covers node count, topological order, and per-node structure
    /// (id, operation variant, sorted parents, `requires_grad`, `is_leaf`).
    /// It does NOT cover tensor values or gradients, so two graphs with the
    /// same structure but different data hash identically — which is fine,
    /// since the cached analysis is structural. `DefaultHasher::new()` uses
    /// fixed keys, so the hash is stable within a process.
    fn compute_graph_hash(&self, graph: &ComputationGraph) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();

        graph.num_nodes().hash(&mut hasher);

        graph.get_topological_order().hash(&mut hasher);

        // Sort nodes by id so the hash is independent of export order.
        let export = graph.export_graph();
        let mut nodes = export.nodes;
        nodes.sort_by_key(|node| node.id);

        for node in nodes {
            node.id.hash(&mut hasher);

            // Hash only which operation variant this is, not its payload.
            if let Some(ref op) = node.operation {
                std::mem::discriminant(op).hash(&mut hasher);
            }

            // Sort parents so edge ordering does not affect the hash.
            let mut parents = node.parents.clone();
            parents.sort();
            parents.hash(&mut hasher);

            node.requires_grad.hash(&mut hasher);
            node.is_leaf.hash(&mut hasher);
        }

        format!("graph_{:x}", hasher.finish())
    }
506
507 fn compute_max_depth(&self, graph: &ComputationGraph, nodes: &[GraphNode]) -> Result<usize> {
508 let mut max_depth = 0;
509 let mut visited = HashSet::new();
510
511 for node in nodes {
512 if node.is_leaf {
513 let depth = self.compute_node_depth(graph, node.id, &mut visited)?;
514 max_depth = max_depth.max(depth);
515 }
516 }
517
518 Ok(max_depth)
519 }
520
    /// Recursively computes the length of the longest child-chain below
    /// `node_id`. A node with no children has depth 0.
    ///
    /// `visited` prevents infinite recursion on cyclic graphs, but it also
    /// means a node reached a second time (e.g. through a diamond-shaped
    /// subgraph) contributes 0 on the revisit. NOTE(review): this can
    /// underestimate depth on graphs with shared subpaths — confirm whether
    /// that is acceptable for this diagnostic, or switch to memoizing the
    /// computed depth per node.
    fn compute_node_depth(
        &self,
        graph: &ComputationGraph,
        node_id: NodeId,
        visited: &mut HashSet<NodeId>,
    ) -> Result<usize> {
        if visited.contains(&node_id) {
            // Already expanded in this traversal: treated as depth 0.
            return Ok(0);
        }
        visited.insert(node_id);

        let node = self.get_node(graph, node_id)?;
        if node.children.is_empty() {
            return Ok(0);
        }

        // Depth is 1 + the deepest child subtree.
        let mut max_child_depth = 0;
        for &child_id in &node.children {
            let child_depth = self.compute_node_depth(graph, child_id, visited)?;
            max_child_depth = max_child_depth.max(child_depth);
        }

        Ok(max_child_depth + 1)
    }
545
546 fn count_operations(&self, nodes: &[GraphNode]) -> HashMap<String, usize> {
547 let mut counts = HashMap::new();
548
549 for node in nodes {
550 if let Some(ref op) = node.operation {
551 let op_name = format!("{:?}", op);
552 *counts.entry(op_name).or_insert(0) += 1;
553 }
554 }
555
556 counts
557 }
558
559 fn analyze_gradient_flow(&self, nodes: &[GraphNode]) -> Result<GradientFlowStats> {
560 let nodes_with_gradients = nodes.iter().filter(|n| n.gradient.is_some()).count();
561 let nodes_requiring_gradients = nodes.iter().filter(|n| n.requires_grad).count();
562
563 let gradient_magnitudes: Vec<f32> = nodes
564 .iter()
565 .filter_map(|node| {
566 node.gradient.as_ref().and_then(|grad| self.compute_tensor_magnitude(grad).ok())
567 })
568 .collect();
569
570 let (average_gradient_magnitude, max_gradient_magnitude, min_gradient_magnitude) =
571 if gradient_magnitudes.is_empty() {
572 (0.0, 0.0, 0.0)
573 } else {
574 let sum: f32 = gradient_magnitudes.iter().sum();
575 let avg = sum / gradient_magnitudes.len() as f32;
576 let max = gradient_magnitudes.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
577 let min = gradient_magnitudes.iter().fold(f32::INFINITY, |a, &b| a.min(b));
578 (avg, max, min)
579 };
580
581 let vanishing_gradient_nodes: Vec<NodeId> = nodes
582 .iter()
583 .filter(|node| {
584 node.gradient
585 .as_ref()
586 .and_then(|grad| self.compute_tensor_magnitude(grad).ok())
587 .map(|mag| mag < self.config.vanishing_gradient_threshold)
588 .unwrap_or(false)
589 })
590 .map(|node| node.id)
591 .collect();
592
593 let exploding_gradient_nodes: Vec<NodeId> = nodes
594 .iter()
595 .filter(|node| {
596 node.gradient
597 .as_ref()
598 .and_then(|grad| self.compute_tensor_magnitude(grad).ok())
599 .map(|mag| mag > self.config.exploding_gradient_threshold)
600 .unwrap_or(false)
601 })
602 .map(|node| node.id)
603 .collect();
604
605 Ok(GradientFlowStats {
606 nodes_with_gradients,
607 nodes_requiring_gradients,
608 average_gradient_magnitude,
609 max_gradient_magnitude,
610 min_gradient_magnitude,
611 vanishing_gradient_nodes,
612 exploding_gradient_nodes,
613 })
614 }
615
616 fn compute_memory_stats(&self, nodes: &[GraphNode]) -> Result<MemoryStats> {
617 let mut total_tensor_memory = 0;
618 let mut total_gradient_memory = 0;
619 let mut memory_per_node = HashMap::new();
620
621 for node in nodes {
622 let tensor_memory = node.value.memory_usage();
623 let gradient_memory = node.gradient.as_ref().map(|g| g.memory_usage()).unwrap_or(0);
624
625 total_tensor_memory += tensor_memory;
626 total_gradient_memory += gradient_memory;
627 memory_per_node.insert(node.id, tensor_memory + gradient_memory);
628 }
629
630 let peak_memory_estimate = total_tensor_memory + total_gradient_memory;
631
632 Ok(MemoryStats {
633 total_tensor_memory,
634 total_gradient_memory,
635 peak_memory_estimate,
636 memory_per_node,
637 })
638 }
639
    /// Converts analysis artifacts into reportable `GraphIssue`s.
    ///
    /// Current checks: vanishing gradients (warning), exploding gradients
    /// (critical), and disconnected nodes (warning; currently inert because
    /// `find_disconnected_nodes` is a stub). `graph` and `nodes` are passed
    /// through to the disconnection check.
    fn detect_issues(
        &self,
        graph: &ComputationGraph,
        nodes: &[GraphNode],
        gradient_stats: &GradientFlowStats,
    ) -> Result<Vec<GraphIssue>> {
        let mut issues = Vec::new();

        // Vanishing gradients: magnitudes below the configured threshold.
        if !gradient_stats.vanishing_gradient_nodes.is_empty() {
            issues.push(GraphIssue {
                issue_type: IssueType::VanishingGradients,
                nodes: gradient_stats.vanishing_gradient_nodes.clone(),
                description: format!(
                    "Detected {} nodes with vanishing gradients",
                    gradient_stats.vanishing_gradient_nodes.len()
                ),
                severity: IssueSeverity::Warning,
                suggestion: Some(
                    "Consider using gradient clipping or adjusting learning rates".to_string(),
                ),
            });
        }

        // Exploding gradients: magnitudes above the configured threshold.
        if !gradient_stats.exploding_gradient_nodes.is_empty() {
            issues.push(GraphIssue {
                issue_type: IssueType::ExplodingGradients,
                nodes: gradient_stats.exploding_gradient_nodes.clone(),
                description: format!(
                    "Detected {} nodes with exploding gradients",
                    gradient_stats.exploding_gradient_nodes.len()
                ),
                severity: IssueSeverity::Critical,
                suggestion: Some("Apply gradient clipping to prevent instability".to_string()),
            });
        }

        // Disconnected subgraphs (no-op today — see find_disconnected_nodes).
        let disconnected_nodes = self.find_disconnected_nodes(graph, nodes)?;
        if !disconnected_nodes.is_empty() {
            issues.push(GraphIssue {
                issue_type: IssueType::DisconnectedSubgraph,
                nodes: disconnected_nodes,
                description: "Found disconnected nodes in the computation graph".to_string(),
                severity: IssueSeverity::Warning,
                suggestion: Some("Check that all variables are properly connected".to_string()),
            });
        }

        Ok(issues)
    }
692
    /// Placeholder for disconnected-subgraph detection.
    ///
    /// Always returns an empty list, so `detect_issues` can never currently
    /// report `DisconnectedSubgraph`. The parameters are retained for the
    /// intended future implementation (the file-level
    /// `#![allow(unused_variables)]` silences the resulting warnings).
    fn find_disconnected_nodes(
        &self,
        graph: &ComputationGraph,
        nodes: &[GraphNode],
    ) -> Result<Vec<NodeId>> {
        Ok(Vec::new())
    }
701
    /// Computes the L2 norm (sqrt of the sum of squares) of a tensor's
    /// elements.
    ///
    /// # Errors
    /// Only `Tensor::F32` is supported; every other variant produces a
    /// `TensorOpError`. Callers in this module generally swallow that error
    /// with `.ok()` or fall back to 0.0.
    fn compute_tensor_magnitude(&self, tensor: &Tensor) -> Result<f32> {
        match tensor {
            Tensor::F32(arr) => {
                let magnitude = arr.iter().map(|&x| x * x).sum::<f32>().sqrt();
                Ok(magnitude)
            },
            _ => Err(TrustformersError::new(
                crate::errors::ErrorKind::TensorOpError {
                    operation: "compute_magnitude".to_string(),
                    reason: "Magnitude computation not supported for this tensor type".to_string(),
                },
            )),
        }
    }
716
    /// Flattens a `GraphNode` into the serializable `NodeDebugInfo` view.
    ///
    /// Magnitudes fall back gracefully when the tensor type is unsupported:
    /// `gradient_magnitude` becomes `None` and `tensor_magnitude` becomes
    /// 0.0. `memory_usage` covers only the value tensor, not the gradient.
    fn create_node_debug_info(&self, node: &GraphNode) -> NodeDebugInfo {
        let gradient_magnitude =
            node.gradient.as_ref().and_then(|grad| self.compute_tensor_magnitude(grad).ok());

        let tensor_magnitude = self.compute_tensor_magnitude(&node.value).unwrap_or(0.0);

        NodeDebugInfo {
            id: node.id,
            name: node.name.clone(),
            operation: node.operation.clone(),
            shape: node.shape.clone(),
            requires_grad: node.requires_grad,
            is_leaf: node.is_leaf,
            has_gradient: node.gradient.is_some(),
            gradient_magnitude,
            tensor_magnitude,
            memory_usage: node.value.memory_usage(),
            parents: node.parents.clone(),
            children: node.children.clone(),
            // NOTE(review): depth is never computed here — always 0. Wire in
            // a traversal if consumers depend on this field.
            depth_from_root: 0,
        }
    }
739
740 fn generate_dot_graph(&self, graph: &ComputationGraph) -> Result<String> {
741 let mut dot = String::new();
742 writeln!(dot, "digraph ComputationGraph {{")?;
743 writeln!(dot, " rankdir=TB;")?;
744 writeln!(dot, " node [shape=box, style=filled, fontname=Arial];")?;
745
746 let nodes = self.get_all_nodes(graph)?;
747
748 for node in &nodes {
749 let color = if node.is_leaf {
750 "lightblue"
751 } else if node.gradient.is_some() {
752 "lightgreen"
753 } else {
754 "lightgray"
755 };
756
757 let label = if let Some(ref name) = node.name {
758 format!(
759 "{}\\n{:?}",
760 name,
761 node.operation.as_ref().unwrap_or(&OperationType::Add)
762 )
763 } else {
764 format!(
765 "Node {}\\n{:?}",
766 node.id,
767 node.operation.as_ref().unwrap_or(&OperationType::Add)
768 )
769 };
770
771 writeln!(
772 dot,
773 " {} [label=\"{}\", fillcolor={}];",
774 node.id, label, color
775 )?;
776 }
777
778 for node in &nodes {
779 for &child_id in &node.children {
780 writeln!(dot, " {} -> {};", node.id, child_id)?;
781 }
782 }
783
784 writeln!(dot, "}}")?;
785 Ok(dot)
786 }
787
788 fn generate_ascii_graph(&self, graph: &ComputationGraph) -> Result<String> {
789 let mut output = String::new();
790 writeln!(output, "Computation Graph (ASCII)")?;
791 writeln!(output, "=========================")?;
792
793 let nodes = self.get_all_nodes(graph)?;
794
795 for node in &nodes {
796 let status = if node.is_leaf { "[LEAF]" } else { "[OP]" };
797 let grad_status = if node.gradient.is_some() { "[GRAD]" } else { "" };
798
799 writeln!(
800 output,
801 "Node {}: {} {} {:?}",
802 node.id,
803 status,
804 grad_status,
805 node.operation.as_ref().unwrap_or(&OperationType::Add)
806 )?;
807
808 if !node.children.is_empty() {
809 writeln!(output, " └─ Children: {:?}", node.children)?;
810 }
811 }
812
813 Ok(output)
814 }
815
    /// Serializes every node (as `NodeDebugInfo`) plus the node count into a
    /// pretty-printed JSON document.
    fn generate_json_graph(&self, graph: &ComputationGraph) -> Result<String> {
        let nodes = self.get_all_nodes(graph)?;
        let debug_nodes: Vec<NodeDebugInfo> =
            nodes.iter().map(|node| self.create_node_debug_info(node)).collect();

        let json_data = serde_json::json!({
            "nodes": debug_nodes,
            "total_nodes": nodes.len(),
        });

        Ok(serde_json::to_string_pretty(&json_data)?)
    }
828
829 fn generate_html_graph(&self, graph: &ComputationGraph) -> Result<String> {
830 let mut html = String::new();
831
832 html.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
833 html.push_str("<title>Computation Graph Debug View</title>\n");
834 html.push_str("<style>\n");
835 html.push_str("body { font-family: Arial, sans-serif; margin: 20px; }\n");
836 html.push_str(
837 ".node { border: 1px solid #ccc; margin: 10px; padding: 10px; border-radius: 5px; }\n",
838 );
839 html.push_str(".leaf { background-color: #e3f2fd; }\n");
840 html.push_str(".op { background-color: #f3e5f5; }\n");
841 html.push_str(".grad { border-left: 4px solid #4caf50; }\n");
842 html.push_str("</style>\n");
843 html.push_str("</head>\n<body>\n");
844
845 html.push_str("<h1>Computation Graph Debug View</h1>\n");
846
847 let nodes = self.get_all_nodes(graph)?;
848
849 for node in &nodes {
850 let node_class = if node.is_leaf { "node leaf" } else { "node op" };
851 let grad_class = if node.gradient.is_some() { " grad" } else { "" };
852
853 html.push_str(&format!("<div class=\"{}{}\">\n", node_class, grad_class));
854 html.push_str(&format!("<h3>Node {}</h3>\n", node.id));
855
856 if let Some(ref name) = node.name {
857 html.push_str(&format!("<p><strong>Name:</strong> {}</p>\n", name));
858 }
859
860 if let Some(ref op) = node.operation {
861 html.push_str(&format!("<p><strong>Operation:</strong> {:?}</p>\n", op));
862 }
863
864 html.push_str(&format!(
865 "<p><strong>Shape:</strong> {:?}</p>\n",
866 node.shape
867 ));
868 html.push_str(&format!(
869 "<p><strong>Requires Grad:</strong> {}</p>\n",
870 node.requires_grad
871 ));
872 html.push_str(&format!(
873 "<p><strong>Is Leaf:</strong> {}</p>\n",
874 node.is_leaf
875 ));
876 html.push_str(&format!(
877 "<p><strong>Has Gradient:</strong> {}</p>\n",
878 node.gradient.is_some()
879 ));
880 html.push_str(&format!(
881 "<p><strong>Memory:</strong> {} bytes</p>\n",
882 node.value.memory_usage()
883 ));
884
885 html.push_str("</div>\n");
886 }
887
888 html.push_str("</body>\n</html>\n");
889 Ok(html)
890 }
891}
892
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor wires in the documented defaults.
    #[test]
    fn test_debugger_creation() {
        let debugger = GraphDebugger::new();
        assert_eq!(debugger.config.max_display_nodes, 50);
        assert_eq!(debugger.config.output_format, GraphOutputFormat::Dot);
    }

    // Config survives a JSON round-trip (serde derives).
    #[test]
    fn test_config_serialization() {
        let config = DebuggerConfig::default();
        let serialized = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: DebuggerConfig =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");

        assert_eq!(config.max_display_nodes, deserialized.max_display_nodes);
        assert_eq!(config.show_gradients, deserialized.show_gradients);
    }

    // set/remove/should_break behave like a set of node ids.
    #[test]
    fn test_breakpoint_management() {
        let mut debugger = GraphDebugger::new();

        debugger.set_breakpoint(1);
        debugger.set_breakpoint(2);

        assert!(debugger.should_break(1));
        assert!(debugger.should_break(2));
        assert!(!debugger.should_break(3));

        debugger.remove_breakpoint(1);
        assert!(!debugger.should_break(1));
        assert!(debugger.should_break(2));
    }

    // Severity variants pattern-match as expected.
    #[test]
    fn test_issue_severity() {
        assert!(matches!(IssueSeverity::Critical, IssueSeverity::Critical));
        assert!(matches!(IssueSeverity::Warning, IssueSeverity::Warning));
        assert!(matches!(IssueSeverity::Info, IssueSeverity::Info));
    }

    // GraphIssue fields are stored and readable.
    #[test]
    fn test_issue_types() {
        let issue = GraphIssue {
            issue_type: IssueType::VanishingGradients,
            nodes: vec![1, 2, 3],
            description: "Test issue".to_string(),
            severity: IssueSeverity::Warning,
            suggestion: Some("Test suggestion".to_string()),
        };

        assert_eq!(issue.issue_type, IssueType::VanishingGradients);
        assert_eq!(issue.nodes.len(), 3);
        assert!(issue.suggestion.is_some());
    }

    // Output format derives Eq; distinct variants compare unequal.
    #[test]
    fn test_output_formats() {
        assert_eq!(GraphOutputFormat::Dot, GraphOutputFormat::Dot);
        assert_ne!(GraphOutputFormat::Dot, GraphOutputFormat::ASCII);
    }
}