1use crate::graph::{Graph, TensorID};
32use crate::tensor::Tensor;
33use crate::{Context, Float};
34use std::collections::{HashMap, HashSet};
35use std::fmt::Write;
36
37fn collect_reachable_nodes<F: Float>(root_id: TensorID, graph: &Graph<F>) -> Vec<TensorID> {
44 let mut visited = HashSet::new();
45 let mut order = Vec::new();
46 collect_dfs(root_id, graph, &mut visited, &mut order);
47 order
48}
49
50fn collect_dfs<F: Float>(
51 node_id: TensorID,
52 graph: &Graph<F>,
53 visited: &mut HashSet<TensorID>,
54 order: &mut Vec<TensorID>,
55) {
56 if visited.contains(&node_id) {
57 return;
58 }
59 visited.insert(node_id);
60
61 let node = graph.access_inner(node_id);
62 let incoming = node.incoming_nodes.clone();
63 drop(node); for inc in &incoming {
66 collect_dfs(inc.id, graph, visited, order);
67 }
68
69 order.push(node_id);
70}
71
72struct NodeInfo {
74 id: TensorID,
75 op_name: String,
76 topo_rank: usize,
77 is_differentiable: bool,
78 is_placeholder: bool,
79 placeholder_name: Option<String>,
80 is_variable: bool,
81 num_inputs: usize,
82 input_ids: Vec<TensorID>,
83 known_shape: Option<Vec<isize>>,
84}
85
86fn extract_node_info<F: Float>(node_id: TensorID, graph: &Graph<F>) -> NodeInfo {
87 let node = graph.access_inner(node_id);
88 let op_name = node
89 .op
90 .as_ref()
91 .map(|o| {
92 let full = o.name();
93 full.rsplit("::").next().unwrap_or(full).to_string()
95 })
96 .unwrap_or_else(|| "Source".to_string());
97
98 let input_ids: Vec<TensorID> = node.incoming_nodes.iter().map(|inc| inc.id).collect();
99 let num_inputs = input_ids.len();
100
101 let known_shape = node.knownshape.as_ref().map(|ks| ks.get().to_vec());
102
103 NodeInfo {
104 id: node_id,
105 op_name,
106 topo_rank: node.topo_rank,
107 is_differentiable: node.is_differentiable,
108 is_placeholder: node.placeholder_name.is_some(),
109 placeholder_name: node.placeholder_name.map(|s| s.to_string()),
110 is_variable: node.variable_id.is_some(),
111 num_inputs,
112 input_ids,
113 known_shape,
114 }
115}
116
117pub fn graph_to_dot<'g, F: Float>(root: &Tensor<'g, F>, ctx: &'g Context<'g, F>) -> String {
141 let graph = get_graph(ctx);
142 let nodes = collect_reachable_nodes(root.id(), graph);
143 let node_set: HashSet<TensorID> = nodes.iter().copied().collect();
144
145 let mut output = String::new();
146 let _ = writeln!(output, "digraph computation_graph {{");
147 let _ = writeln!(output, " rankdir=BT;");
148 let _ = writeln!(
149 output,
150 " node [shape=box, style=\"rounded,filled\", fontname=\"Helvetica\"];"
151 );
152 let _ = writeln!(output, " edge [color=gray50];");
153 let _ = writeln!(output);
154
155 for &nid in &nodes {
157 let info = extract_node_info(nid, graph);
158 let label = node_label(&info);
159 let style = node_color(&info);
160 let _ = writeln!(output, " n{nid} [label=\"{label}\", {style}];");
161 }
162
163 let _ = writeln!(output);
164
165 for &nid in &nodes {
167 let info = extract_node_info(nid, graph);
168 for &src in &info.input_ids {
169 if node_set.contains(&src) {
170 let _ = writeln!(output, " n{src} -> n{nid};");
171 }
172 }
173 }
174
175 let _ = writeln!(output, "}}");
176 output
177}
178
179fn node_label(info: &NodeInfo) -> String {
180 let mut label = String::new();
181
182 if let Some(ref name) = info.placeholder_name {
183 let _ = write!(label, "{name}\\n");
184 }
185
186 let _ = write!(label, "{}", info.op_name);
187
188 if let Some(ref shape) = info.known_shape {
189 let _ = write!(label, "\\n{shape:?}");
190 }
191
192 let _ = write!(label, "\\n(id={})", info.id);
193 label
194}
195
196fn node_color(info: &NodeInfo) -> String {
197 if info.is_placeholder {
198 "fillcolor=\"#d5f5d5\"".to_string() } else if info.is_variable {
200 "fillcolor=\"#fff8d5\"".to_string() } else if info.is_differentiable {
202 "fillcolor=\"#d5e8f5\"".to_string() } else {
204 "fillcolor=\"#e8e8e8\"".to_string() }
206}
207
208pub fn graph_summary<'g, F: Float>(root: &Tensor<'g, F>, ctx: &'g Context<'g, F>) -> String {
218 let graph = get_graph(ctx);
219 let nodes = collect_reachable_nodes(root.id(), graph);
220
221 let mut output = String::new();
222 let _ = writeln!(output, "Computation Graph Summary");
223 let _ = writeln!(output, "=========================");
224 let _ = writeln!(output, "Total nodes: {}", nodes.len());
225
226 let mut placeholders = 0usize;
227 let mut variables = 0usize;
228 let mut ops = 0usize;
229 let mut max_rank = 0usize;
230 let mut op_counts: HashMap<String, usize> = HashMap::new();
231
232 for &nid in &nodes {
233 let info = extract_node_info(nid, graph);
234 if info.is_placeholder {
235 placeholders += 1;
236 } else if info.is_variable {
237 variables += 1;
238 } else {
239 ops += 1;
240 }
241 if info.topo_rank > max_rank {
242 max_rank = info.topo_rank;
243 }
244 *op_counts.entry(info.op_name.clone()).or_insert(0) += 1;
245 }
246
247 let _ = writeln!(output, " Placeholders: {placeholders}");
248 let _ = writeln!(output, " Variables: {variables}");
249 let _ = writeln!(output, " Operations: {ops}");
250 let _ = writeln!(output, " Max depth (topo rank): {max_rank}");
251 let _ = writeln!(output);
252
253 let mut sorted_ops: Vec<_> = op_counts.into_iter().collect();
255 sorted_ops.sort_by_key(|item| std::cmp::Reverse(item.1));
256
257 let _ = writeln!(output, "Operation breakdown:");
258 for (name, count) in &sorted_ops {
259 let _ = writeln!(output, " {name}: {count}");
260 }
261
262 output
263}
264
265pub fn graph_to_json<'g, F: Float>(root: &Tensor<'g, F>, ctx: &'g Context<'g, F>) -> String {
279 let graph = get_graph(ctx);
280 let nodes = collect_reachable_nodes(root.id(), graph);
281 let node_set: HashSet<TensorID> = nodes.iter().copied().collect();
282
283 let mut output = String::new();
284 let _ = writeln!(output, "{{");
285 let _ = writeln!(output, " \"nodes\": [");
286
287 for (idx, &nid) in nodes.iter().enumerate() {
288 let info = extract_node_info(nid, graph);
289 let comma = if idx + 1 < nodes.len() { "," } else { "" };
290 let shape_str = info
291 .known_shape
292 .as_ref()
293 .map(|s| format!("{s:?}"))
294 .unwrap_or_else(|| "null".to_string());
295 let _ = writeln!(
296 output,
297 " {{\"id\": {}, \"op\": \"{}\", \"rank\": {}, \"differentiable\": {}, \"shape\": {}}}{}",
298 info.id, info.op_name, info.topo_rank, info.is_differentiable, shape_str, comma
299 );
300 }
301
302 let _ = writeln!(output, " ],");
303 let _ = writeln!(output, " \"edges\": [");
304
305 let mut edge_idx = 0usize;
306 let total_edges: usize = nodes
307 .iter()
308 .map(|&nid| {
309 let info = extract_node_info(nid, graph);
310 info.input_ids
311 .iter()
312 .filter(|id| node_set.contains(id))
313 .count()
314 })
315 .sum();
316
317 for &nid in &nodes {
318 let info = extract_node_info(nid, graph);
319 for &src in &info.input_ids {
320 if node_set.contains(&src) {
321 edge_idx += 1;
322 let comma = if edge_idx < total_edges { "," } else { "" };
323 let _ = writeln!(
324 output,
325 " {{\"from\": {src}, \"to\": {}}}{comma}",
326 info.id
327 );
328 }
329 }
330 }
331
332 let _ = writeln!(output, " ]");
333 let _ = writeln!(output, "}}");
334 output
335}
336
337pub fn graph_to_mermaid<'g, F: Float>(root: &Tensor<'g, F>, ctx: &'g Context<'g, F>) -> String {
345 let graph = get_graph(ctx);
346 let nodes = collect_reachable_nodes(root.id(), graph);
347 let node_set: HashSet<TensorID> = nodes.iter().copied().collect();
348
349 let mut output = String::new();
350 let _ = writeln!(output, "graph BT");
351
352 for &nid in &nodes {
353 let info = extract_node_info(nid, graph);
354 let label = if let Some(ref name) = info.placeholder_name {
355 format!("{name}: {}", info.op_name)
356 } else {
357 info.op_name.clone()
358 };
359 let _ = writeln!(output, " N{nid}[\"{label}\"]");
360 }
361
362 for &nid in &nodes {
363 let info = extract_node_info(nid, graph);
364 for &src in &info.input_ids {
365 if node_set.contains(&src) {
366 let _ = writeln!(output, " N{src} --> N{nid}");
367 }
368 }
369 }
370
371 output
372}
373
/// Aggregate statistics about the computation graph reachable from a tensor.
///
/// Built by `GraphStats::from_tensor` and rendered as text by `GraphStats::display`.
/// `PartialEq`/`Eq` let callers compare snapshots (e.g. in tests). `Default` yields the
/// all-zero, empty-breakdown value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphStats {
    /// Number of distinct nodes reachable from the root (root included).
    pub total_nodes: usize,
    /// Nodes that have a placeholder name.
    pub num_placeholders: usize,
    /// Variable nodes that are not placeholders.
    pub num_variables: usize,
    /// Nodes that are neither placeholders nor variables.
    pub num_operations: usize,
    /// Largest `topo_rank` among the reachable nodes.
    pub max_depth: usize,
    /// Total number of input slots over all reachable nodes.
    pub num_edges: usize,
    /// Nodes flagged as differentiable.
    pub num_differentiable: usize,
    /// `(op name, node count)` pairs, most frequent first.
    pub op_breakdown: Vec<(String, usize)>,
    /// Largest number of inputs of a single node.
    pub max_fan_in: usize,
    /// Largest number of times a single node is used as an input.
    pub max_fan_out: usize,
}
402
403impl GraphStats {
404 pub fn from_tensor<'g, F: Float>(root: &Tensor<'g, F>, ctx: &'g Context<'g, F>) -> Self {
406 let graph = get_graph(ctx);
407 let nodes = collect_reachable_nodes(root.id(), graph);
408
409 let mut num_placeholders = 0usize;
410 let mut num_variables = 0usize;
411 let mut num_operations = 0usize;
412 let mut num_differentiable = 0usize;
413 let mut max_depth = 0usize;
414 let mut num_edges = 0usize;
415 let mut max_fan_in = 0usize;
416 let mut op_counts: HashMap<String, usize> = HashMap::new();
417 let mut fan_out: HashMap<TensorID, usize> = HashMap::new();
418
419 for &nid in &nodes {
420 let info = extract_node_info(nid, graph);
421
422 if info.is_placeholder {
423 num_placeholders += 1;
424 } else if info.is_variable {
425 num_variables += 1;
426 } else {
427 num_operations += 1;
428 }
429
430 if info.is_differentiable {
431 num_differentiable += 1;
432 }
433
434 if info.topo_rank > max_depth {
435 max_depth = info.topo_rank;
436 }
437
438 num_edges += info.num_inputs;
439
440 if info.num_inputs > max_fan_in {
441 max_fan_in = info.num_inputs;
442 }
443
444 for &src in &info.input_ids {
445 *fan_out.entry(src).or_insert(0) += 1;
446 }
447
448 *op_counts.entry(info.op_name).or_insert(0) += 1;
449 }
450
451 let max_fan_out = fan_out.values().copied().max().unwrap_or(0);
452
453 let mut op_breakdown: Vec<_> = op_counts.into_iter().collect();
454 op_breakdown.sort_by_key(|item| std::cmp::Reverse(item.1));
455
456 GraphStats {
457 total_nodes: nodes.len(),
458 num_placeholders,
459 num_variables,
460 num_operations,
461 max_depth,
462 num_edges,
463 num_differentiable,
464 op_breakdown,
465 max_fan_in,
466 max_fan_out,
467 }
468 }
469
470 pub fn display(&self) -> String {
472 let mut output = String::new();
473 let _ = writeln!(output, "Graph Statistics");
474 let _ = writeln!(output, "================");
475 let _ = writeln!(output, "Total nodes: {}", self.total_nodes);
476 let _ = writeln!(output, "Placeholders: {}", self.num_placeholders);
477 let _ = writeln!(output, "Variables: {}", self.num_variables);
478 let _ = writeln!(output, "Operations: {}", self.num_operations);
479 let _ = writeln!(output, "Edges: {}", self.num_edges);
480 let _ = writeln!(output, "Max depth: {}", self.max_depth);
481 let _ = writeln!(output, "Differentiable: {}", self.num_differentiable);
482 let _ = writeln!(output, "Max fan-in: {}", self.max_fan_in);
483 let _ = writeln!(output, "Max fan-out: {}", self.max_fan_out);
484 let _ = writeln!(output);
485 let _ = writeln!(output, "Operation breakdown:");
486 for (name, count) in &self.op_breakdown {
487 let _ = writeln!(output, " {name}: {count}");
488 }
489 output
490 }
491}
492
/// Options describing how a graph visualization should be produced.
///
/// NOTE(review): none of the exporters in this module read this struct. The per-field
/// descriptions follow the field names and should be confirmed against its consumers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VisualizationConfig {
    /// Presumably: include known tensor shapes in node labels.
    pub show_shapes: bool,
    /// Presumably: include operation names in node labels.
    pub show_operations: bool,
    /// Presumably: include gradient nodes in the rendering.
    pub show_gradients: bool,
    /// Presumably: upper bound on rendered nodes; `None` means unbounded.
    pub max_nodes: Option<usize>,
    /// Requested output format.
    pub format: OutputFormat,
    /// Presumably: include tensor values in node labels.
    pub show_values: bool,
}

impl Default for VisualizationConfig {
    /// Shapes and operations on; gradients and values off; at most 100 nodes; DOT output.
    fn default() -> Self {
        Self {
            show_shapes: true,
            show_operations: true,
            show_gradients: false,
            max_nodes: Some(100),
            format: OutputFormat::Dot,
            show_values: false,
        }
    }
}

/// Output formats a visualization can be rendered in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Graphviz DOT (as produced by `graph_to_dot`).
    Dot,
    /// Plain text (presumably the `graph_summary` report).
    Text,
    /// JSON (as produced by `graph_to_json`).
    Json,
    /// Mermaid flowchart (as produced by `graph_to_mermaid`).
    Mermaid,
}
539
540#[derive(Debug, thiserror::Error)]
542pub enum VisualizationError {
543 #[error("Graph traversal error: {0}")]
544 GraphTraversal(String),
545 #[error("Format error: {0}")]
546 Format(#[from] std::fmt::Error),
547 #[error("IO error: {0}")]
548 Io(#[from] std::io::Error),
549 #[error("Invalid configuration: {0}")]
550 Config(String),
551}
552
553fn get_graph<'g, F: Float>(ctx: &'g Context<'g, F>) -> &'g Graph<F> {
559 use std::ops::Deref;
560 ctx.deref()
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops;

    #[test]
    fn test_graph_to_dot_basic() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y = x * 2.0;
            let loss = tensor_ops::reduction::sum_all(y);

            let dot = graph_to_dot(&loss, ctx);
            assert!(dot.starts_with("digraph computation_graph {"));
            assert!(dot.contains("->"));
            assert!(dot.trim_end().ends_with('}'));
            // Both the leaf and the root are declared as nodes.
            assert!(dot.contains(&format!(" n{} [label=", x.id())));
            assert!(dot.contains(&format!(" n{} [label=", loss.id())));
            // The placeholder feeds at least one reachable consumer.
            assert!(dot.contains(&format!(" n{} -> ", x.id())));
        });
    }

    #[test]
    fn test_graph_to_dot_multi_input() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = ctx.placeholder("y", &[2]);
            let z = x + y;
            let loss = tensor_ops::reduction::sum_all(z);

            let dot = graph_to_dot(&loss, ctx);
            assert!(dot.contains("digraph computation_graph"));
            assert!(dot.contains("fillcolor"));
            assert!(dot.contains(&format!(" n{} [label=", x.id())));
            assert!(dot.contains(&format!(" n{} [label=", y.id())));
        });
    }

    #[test]
    fn test_graph_summary_basic() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y = x * 2.0 + 1.0;
            let loss = tensor_ops::reduction::sum_all(y);

            let summary = graph_summary(&loss, ctx);
            assert!(summary.contains("Computation Graph Summary"));
            assert!(summary.contains("Total nodes:"));
            assert!(summary.contains("Placeholders:"));
            assert!(summary.contains("Operation breakdown:"));
        });
    }

    #[test]
    fn test_graph_to_json() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = x * 3.0;

            let json = graph_to_json(&y, ctx);
            assert!(json.contains("\"nodes\""));
            assert!(json.contains("\"edges\""));
            assert!(json.contains("\"op\""));

            // Structural sanity independent of indentation.
            let compact: String = json.split_whitespace().collect();
            assert!(compact.starts_with('{'));
            assert!(compact.ends_with('}'));
            assert!(!compact.contains(",]"), "trailing comma before ']' in {json}");
            assert_eq!(compact.matches('{').count(), compact.matches('}').count());
            assert!(compact.contains(&format!("\"from\":{},", x.id())));
        });
    }

    #[test]
    fn test_graph_to_mermaid() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = x * 2.0;

            let mermaid = graph_to_mermaid(&y, ctx);
            assert!(mermaid.starts_with("graph BT"));
            assert!(mermaid.contains("-->"));
            assert!(mermaid.contains(&format!(" N{} --> ", x.id())));
        });
    }

    #[test]
    fn test_graph_stats() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y = x * 2.0 + 1.0;
            let loss = tensor_ops::reduction::sum_all(y);

            let stats = GraphStats::from_tensor(&loss, ctx);
            assert!(stats.total_nodes > 0);
            assert!(stats.num_placeholders >= 1);
            assert!(stats.num_edges > 0);
            assert!(stats.max_depth > 0);
            assert!(!stats.op_breakdown.is_empty());
            // Every node falls into exactly one of the three categories.
            assert_eq!(
                stats.num_placeholders + stats.num_variables + stats.num_operations,
                stats.total_nodes
            );
            assert!(stats.max_fan_in >= 1);
            assert!(stats.max_fan_out >= 1);
            let breakdown_total: usize = stats.op_breakdown.iter().map(|(_, c)| *c).sum();
            assert_eq!(breakdown_total, stats.total_nodes);
        });
    }

    #[test]
    fn test_graph_stats_display() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let loss = tensor_ops::reduction::sum_all(x * x);

            let stats = GraphStats::from_tensor(&loss, ctx);
            let display = stats.display();
            assert!(display.contains("Graph Statistics"));
            assert!(display.contains("Total nodes:"));
            assert!(display.contains("Max fan-in:"));
            assert!(display.contains("Max fan-out:"));
            assert!(display.contains("Operation breakdown:"));
        });
    }

    #[test]
    fn test_graph_dot_colors() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = x * 2.0;

            let dot = graph_to_dot(&y, ctx);
            assert!(dot.contains("fillcolor"));
            // The placeholder gets the placeholder colour.
            assert!(dot.contains("#d5f5d5"));
        });
    }

    #[test]
    fn test_visualization_config_default() {
        let config = VisualizationConfig::default();
        assert!(config.show_shapes);
        assert!(config.show_operations);
        assert!(!config.show_gradients);
        assert!(!config.show_values);
        assert_eq!(config.max_nodes, Some(100));
        assert!(matches!(config.format, OutputFormat::Dot));
    }

    #[test]
    fn test_graph_stats_single_node() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[]);

            let stats = GraphStats::from_tensor(&x, ctx);
            assert_eq!(stats.total_nodes, 1);
            assert_eq!(stats.num_placeholders, 1);
            assert_eq!(stats.num_operations, 0);
            assert_eq!(stats.num_edges, 0);
            assert_eq!(stats.max_fan_in, 0);
            assert_eq!(stats.max_fan_out, 0);
            assert_eq!(stats.op_breakdown.len(), 1);
            assert_eq!(stats.op_breakdown[0].1, 1);
        });
    }

    #[test]
    fn test_collect_reachable_shared_nodes() {
        crate::run(|ctx: &mut crate::Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = x + x;
            let loss = tensor_ops::reduction::sum_all(y);

            let graph: &Graph<f64> = std::ops::Deref::deref(ctx);
            let nodes = collect_reachable_nodes(loss.id(), graph);

            // A node used twice is still listed once, and no node is duplicated.
            let x_count = nodes.iter().filter(|&&id| id == x.id()).count();
            assert_eq!(x_count, 1);
            let unique: HashSet<TensorID> = nodes.iter().copied().collect();
            assert_eq!(unique.len(), nodes.len());

            // Post-order: inputs precede consumers and the root comes last.
            assert_eq!(nodes.last(), Some(&loss.id()));
            let pos = |id: TensorID| {
                nodes
                    .iter()
                    .position(|&n| n == id)
                    .expect("node should be reachable from loss")
            };
            assert!(pos(x.id()) < pos(y.id()));
            assert!(pos(y.id()) < pos(loss.id()));
        });
    }
}