1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use uuid::Uuid;
12
/// A directed acyclic graph of compute operations, stored edge-reversed:
/// `edges[node]` lists the nodes that `node` depends on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph {
    /// Unique identifier for this graph.
    pub id: Uuid,
    /// All nodes, keyed by node id.
    pub nodes: HashMap<String, GraphNode>,
    /// Reverse adjacency: node id -> ids of the nodes it depends on.
    pub edges: HashMap<String, Vec<String>>,
    /// Nodes with no dependencies (graph inputs).
    pub root_nodes: HashSet<String>,
    /// Nodes that no other node depends on (graph outputs).
    pub leaf_nodes: HashSet<String>,
    /// Aggregate statistics computed at creation time.
    pub metadata: GraphMetadata,
}
29
/// Summary statistics captured when a graph is constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Human-readable graph name.
    pub name: String,
    /// Total number of nodes.
    pub node_count: usize,
    /// Total number of dependency edges.
    pub edge_count: usize,
    /// Deepest BFS level (0 for a graph of only roots).
    pub max_depth: usize,
    /// Sum of per-node memory estimates, in bytes.
    pub estimated_memory_usage: u64,
    /// Sum of per-node FLOP estimates.
    pub estimated_flops: u64,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
48
/// One operation in the computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node identifier (also the key in `ComputationGraph::nodes`).
    pub id: String,
    /// Display name (set equal to `id` by `create_graph`).
    pub name: String,
    /// Kind of computation this node performs.
    pub operation_type: OperationType,
    /// Shapes of each input tensor (may be empty when unknown).
    pub input_shapes: Vec<Vec<usize>>,
    /// Shapes of each output tensor (may be empty when unknown).
    pub output_shapes: Vec<Vec<usize>>,
    /// Estimated floating-point operation count.
    pub flop_count: u64,
    /// Estimated memory footprint, in bytes.
    pub memory_usage: u64,
    /// Measured execution time in microseconds, when profiled.
    pub execution_time_us: Option<u64>,
    /// Learnable-parameter count for weight-bearing ops, else `None`.
    pub parameter_count: Option<u64>,
    /// Position in a topological ordering, once computed.
    pub topo_order: Option<usize>,
    /// BFS depth from the root layer (roots are 0).
    pub depth: usize,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
77
/// The kinds of operations a graph node can represent.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OperationType {
    // Elementwise / linear-algebra arithmetic.
    Add,
    Subtract,
    Multiply,
    Divide,
    MatMul,
    Dot,

    // Activation functions.
    ReLU,
    Sigmoid,
    Tanh,
    GELU,
    Softmax,

    // Normalization layers.
    LayerNorm,
    BatchNorm,
    RMSNorm,

    // Convolutions.
    Conv1D,
    Conv2D,
    Conv3D,
    ConvTranspose,

    // Pooling.
    MaxPool,
    AvgPool,
    AdaptivePool,

    // Tensor shape manipulation / data movement.
    Reshape,
    Transpose,
    Concat,
    Split,
    Slice,
    Gather,
    Scatter,

    // Reductions.
    Sum,
    Mean,
    Max,
    Min,

    // Attention mechanisms.
    Attention,
    MultiHeadAttention,
    SelfAttention,
    CrossAttention,

    // Embedding lookups.
    Embedding,
    PositionalEmbedding,

    // Loss functions.
    CrossEntropyLoss,
    MSELoss,
    L1Loss,

    // Control flow.
    If,
    While,
    Loop,

    /// Escape hatch for operations not modeled above.
    Custom(String),
}
150
/// Toggles and thresholds controlling which analysis passes run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisConfig {
    /// Run the memory-usage pass.
    pub enable_memory_analysis: bool,
    /// Run the FLOP-count pass.
    pub enable_flop_analysis: bool,
    /// Run optimization-opportunity detection.
    pub enable_optimization_analysis: bool,
    /// Run bottleneck detection.
    pub enable_bottleneck_detection: bool,
    /// Run dataflow / lifetime analysis.
    pub enable_dataflow_analysis: bool,
    /// Execution time (microseconds) above which a node is a bottleneck.
    pub bottleneck_threshold_us: u64,
    /// Memory footprint (bytes) above which a node counts as "large".
    pub large_memory_threshold: u64,
}
169
170impl Default for GraphAnalysisConfig {
171 fn default() -> Self {
172 Self {
173 enable_memory_analysis: true,
174 enable_flop_analysis: true,
175 enable_optimization_analysis: true,
176 enable_bottleneck_detection: true,
177 enable_dataflow_analysis: true,
178 bottleneck_threshold_us: 1000, large_memory_threshold: 1024 * 1024 * 100, }
181 }
182}
183
/// Stateful analyzer that stores graphs and caches their analysis results.
#[derive(Debug)]
pub struct ComputationGraphAnalyzer {
    // Which passes run and their thresholds.
    config: GraphAnalysisConfig,
    // Registered graphs, keyed by graph id.
    graphs: HashMap<Uuid, ComputationGraph>,
    // Results cached by `analyze_graph`, keyed by graph id.
    analysis_results: HashMap<Uuid, GraphAnalysisResult>,
}
191
/// Aggregated output of all enabled analysis passes for one graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisResult {
    /// Graph this result belongs to.
    pub graph_id: Uuid,
    /// Memory pass output (`None` when the pass is disabled).
    pub memory_analysis: Option<MemoryAnalysis>,
    /// FLOP pass output (`None` when the pass is disabled).
    pub flop_analysis: Option<FlopAnalysis>,
    /// Detected optimization candidates (empty when the pass is disabled).
    pub optimization_opportunities: Vec<OptimizationOpportunity>,
    /// Bottleneck pass output (`None` when the pass is disabled).
    pub bottleneck_analysis: Option<BottleneckAnalysis>,
    /// Dataflow pass output (`None` when the pass is disabled).
    pub dataflow_analysis: Option<DataFlowAnalysis>,
    /// Node ids along the heuristic critical path.
    pub critical_path: Vec<String>,
    /// Structural statistics (always computed).
    pub statistics: GraphStatistics,
    /// Human-readable advice derived from the other fields.
    pub recommendations: Vec<String>,
}
214
/// Results of the memory-usage pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAnalysis {
    /// Sum of all per-node estimates, in bytes.
    pub total_memory_usage: u64,
    /// Peak usage; currently equal to the total (no liveness model yet).
    pub peak_memory_usage: u64,
    /// Bytes attributed to each operation type.
    pub memory_by_operation: HashMap<OperationType, u64>,
    /// Top memory-consuming nodes as `(node_id, bytes)`, largest first.
    pub memory_hotspots: Vec<(String, u64)>,
    /// Estimated fragmentation (0.0-1.0); currently a fixed placeholder.
    pub fragmentation_ratio: f64,
    /// Canned suggestions for reducing memory pressure.
    pub optimization_suggestions: Vec<String>,
}
231
/// Results of the FLOP-count pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlopAnalysis {
    /// Sum of all per-node FLOP estimates.
    pub total_flops: u64,
    /// FLOPs attributed to each operation type.
    pub flops_by_operation: HashMap<OperationType, u64>,
    /// Top compute-heavy nodes as `(node_id, flops)`, largest first.
    pub compute_hotspots: Vec<(String, u64)>,
    /// Estimated FLOPs per byte of memory traffic.
    pub arithmetic_intensity: f64,
    /// Coarse asymptotic / parallelism summary.
    pub complexity_analysis: ComplexityAnalysis,
}
246
/// Coarse complexity summary attached to a FLOP analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityAnalysis {
    /// Big-O time estimate (currently a fixed placeholder string).
    pub time_complexity: String,
    /// Big-O space estimate (currently a fixed placeholder string).
    pub space_complexity: String,
    /// Fraction of work believed parallelizable (0.0-1.0).
    pub parallelization_potential: f64,
    /// Length of the longest sequential dependency chain.
    pub sequential_dependencies: usize,
}
259
/// A single detected optimization candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationOpportunity {
    /// Category of optimization.
    pub optimization_type: OptimizationType,
    /// Human-readable explanation.
    pub description: String,
    /// Ids of the nodes the optimization would touch.
    pub affected_nodes: Vec<String>,
    /// Predicted benefit if applied.
    pub estimated_improvement: EstimatedImprovement,
    /// Relative effort to implement (higher = harder).
    pub implementation_difficulty: u8,
    /// Triage priority.
    pub priority: OptimizationPriority,
}
276
/// Categories of graph-level optimizations the analyzer can suggest.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationType {
    /// Combine adjacent ops into one kernel (e.g. MatMul + Add -> GEMM).
    OperationFusion,
    /// Remove duplicated or dead computation.
    RedundancyElimination,
    /// Rearrange data layout for better locality.
    MemoryLayoutOptimization,
    /// Replace an algorithm with an asymptotically better one.
    AlgorithmicOptimization,
    /// Run independent work concurrently.
    Parallelization,
    /// Improve data-access patterns / caching.
    DataAccessOptimization,
    /// Use lower numeric precision where accuracy allows.
    PrecisionOptimization,
    /// Cache and reuse previously computed results.
    Memoization,
    /// Simplify branches and loops.
    ControlFlowOptimization,
}
299
/// Triage priority; derives `Ord`, so `Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum OptimizationPriority {
    Low,
    Medium,
    High,
    Critical,
}
308
/// Predicted benefit of applying an optimization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EstimatedImprovement {
    /// Expected speedup multiplier (1.0 = no change).
    pub speedup_factor: f64,
    /// Expected memory saved, in bytes.
    pub memory_reduction: u64,
    /// Expected fractional energy savings (0.0-1.0).
    pub energy_savings: f64,
}
319
/// Results of the bottleneck-detection pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    /// Nodes whose measured runtime exceeds the configured threshold.
    pub bottleneck_nodes: Vec<String>,
    /// Node ids along the heuristic critical path.
    pub critical_path_nodes: Vec<String>,
    /// Sum of measured runtimes along the critical path, in microseconds.
    pub critical_path_time_us: u64,
    /// Nodes whose operation types parallelize well.
    pub parallelizable_nodes: Vec<String>,
    /// Canned scheduling advice.
    pub scheduling_suggestions: Vec<String>,
}
334
/// Results of the dataflow / lifetime pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFlowAnalysis {
    /// node id -> ids of the nodes it consumes.
    pub data_dependencies: HashMap<String, Vec<String>>,
    /// node id -> values live at that node (its direct inputs).
    pub live_variables: HashMap<String, HashSet<String>>,
    /// Producer node id -> lifetime of the value it produces.
    pub variable_lifetimes: HashMap<String, VariableLifetime>,
    /// Candidate buffer-sharing opportunities.
    pub memory_reuse_opportunities: Vec<MemoryReuseOpportunity>,
}
347
/// Span over which a produced value must remain allocated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableLifetime {
    /// Node that produces the value.
    pub birth_node: String,
    /// Last node observed consuming the value.
    pub death_node: String,
    /// All nodes that consume the value.
    pub usage_nodes: Vec<String>,
    /// Estimated size of the value, in bytes.
    pub memory_footprint: u64,
}
360
/// A set of values whose buffers could share one allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReuseOpportunity {
    /// Values believed to have non-overlapping lifetimes.
    pub reusable_variables: Vec<String>,
    /// Bytes saved by sharing one buffer.
    pub memory_savings: u64,
    /// Relative implementation effort (higher = harder).
    pub complexity: u8,
}
371
/// Structural statistics over the whole graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStatistics {
    /// Node count per operation type.
    pub nodes_by_type: HashMap<OperationType, usize>,
    /// Mean number of incoming edges per node.
    pub average_fan_in: f64,
    /// Mean number of outgoing edges per node.
    pub average_fan_out: f64,
    /// Graph diameter; currently approximated by max depth.
    pub diameter: usize,
    /// Clustering coefficient; currently a fixed placeholder (0.0).
    pub clustering_coefficient: f64,
    /// SCC count; currently just the node count (no cycle detection is done).
    pub strongly_connected_components: usize,
}
388
389impl ComputationGraphAnalyzer {
390 pub fn new(config: GraphAnalysisConfig) -> Self {
392 Self {
393 config,
394 graphs: HashMap::new(),
395 analysis_results: HashMap::new(),
396 }
397 }
398
399 pub fn add_graph(&mut self, graph: ComputationGraph) -> Result<()> {
401 let graph_id = graph.id;
402 self.graphs.insert(graph_id, graph);
403 Ok(())
404 }
405
406 pub fn create_graph(
408 &mut self,
409 name: String,
410 operations: Vec<(String, OperationType, Vec<String>)>, ) -> Result<Uuid> {
412 let graph_id = Uuid::new_v4();
413 let mut nodes = HashMap::new();
414 let mut edges = HashMap::new();
415 let mut root_nodes = HashSet::new();
416 let mut leaf_nodes = HashSet::new();
417
418 for (node_id, op_type, dependencies) in &operations {
420 let node = GraphNode {
421 id: node_id.clone(),
422 name: node_id.clone(),
423 operation_type: op_type.clone(),
424 input_shapes: vec![],
425 output_shapes: vec![],
426 flop_count: self.estimate_flops(op_type, &[]),
427 memory_usage: self.estimate_memory(op_type, &[]),
428 execution_time_us: None,
429 parameter_count: self.estimate_parameters(op_type),
430 topo_order: None,
431 depth: 0,
432 metadata: HashMap::new(),
433 };
434 nodes.insert(node_id.clone(), node);
435
436 if dependencies.is_empty() {
438 root_nodes.insert(node_id.clone());
439 }
440 edges.insert(node_id.clone(), dependencies.clone());
441 }
442
443 let all_dependencies: HashSet<String> = edges.values().flatten().cloned().collect();
445 for node_id in nodes.keys() {
446 if !all_dependencies.contains(node_id) {
447 leaf_nodes.insert(node_id.clone());
448 }
449 }
450
451 self.calculate_depth_and_topo_order(&mut nodes, &edges)?;
453
454 let metadata = GraphMetadata {
455 name,
456 node_count: nodes.len(),
457 edge_count: edges.values().map(|deps| deps.len()).sum(),
458 max_depth: nodes.values().map(|n| n.depth).max().unwrap_or(0),
459 estimated_memory_usage: nodes.values().map(|n| n.memory_usage).sum(),
460 estimated_flops: nodes.values().map(|n| n.flop_count).sum(),
461 created_at: chrono::Utc::now(),
462 };
463
464 let graph = ComputationGraph {
465 id: graph_id,
466 nodes,
467 edges,
468 root_nodes,
469 leaf_nodes,
470 metadata,
471 };
472
473 self.graphs.insert(graph_id, graph);
474 Ok(graph_id)
475 }
476
477 pub fn analyze_graph(&mut self, graph_id: Uuid) -> Result<GraphAnalysisResult> {
479 let graph = self
480 .graphs
481 .get(&graph_id)
482 .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;
483
484 let mut result = GraphAnalysisResult {
485 graph_id,
486 memory_analysis: None,
487 flop_analysis: None,
488 optimization_opportunities: Vec::new(),
489 bottleneck_analysis: None,
490 dataflow_analysis: None,
491 critical_path: Vec::new(),
492 statistics: self.calculate_statistics(graph)?,
493 recommendations: Vec::new(),
494 };
495
496 if self.config.enable_memory_analysis {
498 result.memory_analysis = Some(self.analyze_memory_usage(graph)?);
499 }
500
501 if self.config.enable_flop_analysis {
502 result.flop_analysis = Some(self.analyze_flop_usage(graph)?);
503 }
504
505 if self.config.enable_optimization_analysis {
506 result.optimization_opportunities = self.detect_optimization_opportunities(graph)?;
507 }
508
509 if self.config.enable_bottleneck_detection {
510 result.bottleneck_analysis = Some(self.analyze_bottlenecks(graph)?);
511 }
512
513 if self.config.enable_dataflow_analysis {
514 result.dataflow_analysis = Some(self.analyze_dataflow(graph)?);
515 }
516
517 result.critical_path = self.find_critical_path(graph)?;
518 result.recommendations = self.generate_recommendations(&result)?;
519
520 self.analysis_results.insert(graph_id, result.clone());
521 Ok(result)
522 }
523
    /// Return the cached analysis for `graph_id`, if `analyze_graph` has run.
    pub fn get_analysis_result(&self, graph_id: Uuid) -> Option<&GraphAnalysisResult> {
        self.analysis_results.get(&graph_id)
    }
528
529 pub fn export_to_dot(&self, graph_id: Uuid) -> Result<String> {
531 let graph = self
532 .graphs
533 .get(&graph_id)
534 .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;
535
536 let mut dot = String::new();
537 dot.push_str(&format!("digraph \"{}\" {{\n", graph.metadata.name));
538 dot.push_str(" rankdir=TB;\n");
539 dot.push_str(" node [shape=box, style=filled];\n\n");
540
541 for node in graph.nodes.values() {
543 let color = self.get_node_color(&node.operation_type);
544 let label = format!(
545 "{}\\n{}\\n{:.1} GFLOP\\n{:.1} MB",
546 node.name,
547 format!("{:?}", node.operation_type),
548 node.flop_count as f64 / 1e9,
549 node.memory_usage as f64 / (1024.0 * 1024.0)
550 );
551
552 dot.push_str(&format!(
553 " \"{}\" [label=\"{}\", fillcolor=\"{}\"];\n",
554 node.id, label, color
555 ));
556 }
557
558 dot.push('\n');
559
560 for (node_id, dependencies) in &graph.edges {
562 for dep in dependencies {
563 dot.push_str(&format!(" \"{}\" -> \"{}\";\n", dep, node_id));
564 }
565 }
566
567 dot.push_str("}\n");
568 Ok(dot)
569 }
570
571 fn calculate_depth_and_topo_order(
574 &self,
575 nodes: &mut HashMap<String, GraphNode>,
576 edges: &HashMap<String, Vec<String>>,
577 ) -> Result<()> {
578 let mut in_degree: HashMap<String, usize> = HashMap::new();
580 let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
581
582 for node_id in nodes.keys() {
584 in_degree.insert(node_id.clone(), 0);
585 adj_list.insert(node_id.clone(), Vec::new());
586 }
587
588 for (node_id, dependencies) in edges {
589 in_degree.insert(node_id.clone(), dependencies.len());
590 for dep in dependencies {
591 adj_list.get_mut(dep).unwrap().push(node_id.clone());
592 }
593 }
594
595 let mut queue = VecDeque::new();
597 let mut topo_order = 0;
598
599 for (node_id, °ree) in &in_degree {
601 if degree == 0 {
602 queue.push_back((node_id.clone(), 0)); }
604 }
605
606 while let Some((node_id, depth)) = queue.pop_front() {
607 if let Some(node) = nodes.get_mut(&node_id) {
609 node.depth = depth;
610 node.topo_order = Some(topo_order);
611 topo_order += 1;
612 }
613
614 if let Some(neighbors) = adj_list.get(&node_id) {
616 for neighbor in neighbors {
617 *in_degree.get_mut(neighbor).unwrap() -= 1;
618 if in_degree[neighbor] == 0 {
619 queue.push_back((neighbor.clone(), depth + 1));
620 }
621 }
622 }
623 }
624
625 Ok(())
626 }
627
628 fn estimate_flops(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
629 match op_type {
631 OperationType::MatMul => {
632 if shapes.len() >= 2 {
633 let a_shape = &shapes[0];
634 let b_shape = &shapes[1];
635 if a_shape.len() >= 2 && b_shape.len() >= 2 {
636 let m = a_shape[a_shape.len() - 2];
637 let k = a_shape[a_shape.len() - 1];
638 let n = b_shape[b_shape.len() - 1];
639 return (2 * m * k * n) as u64;
640 }
641 }
642 1000000 },
644 OperationType::Add | OperationType::Subtract | OperationType::Multiply => {
645 shapes.first().map(|s| s.iter().product::<usize>() as u64).unwrap_or(1000)
646 },
647 OperationType::ReLU | OperationType::Sigmoid | OperationType::Tanh => {
648 shapes.first().map(|s| s.iter().product::<usize>() as u64).unwrap_or(1000)
649 },
650 OperationType::LayerNorm | OperationType::BatchNorm => {
651 shapes.first().map(|s| (s.iter().product::<usize>() * 5) as u64).unwrap_or(5000)
652 },
653 _ => 1000, }
655 }
656
657 fn estimate_memory(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
658 let element_size = 4u64;
660 match op_type {
661 OperationType::MatMul => {
662 shapes
663 .iter()
664 .map(|s| s.iter().product::<usize>() as u64 * element_size)
665 .sum::<u64>()
666 .max(1024) },
668 _ => shapes
669 .first()
670 .map(|s| s.iter().product::<usize>() as u64 * element_size)
671 .unwrap_or(1024),
672 }
673 }
674
675 fn estimate_parameters(&self, op_type: &OperationType) -> Option<u64> {
676 match op_type {
677 OperationType::MatMul => Some(1000000), OperationType::Conv2D => Some(500000),
679 OperationType::Embedding => Some(2000000),
680 OperationType::LayerNorm => Some(1000),
681 _ => None,
682 }
683 }
684
685 fn analyze_memory_usage(&self, graph: &ComputationGraph) -> Result<MemoryAnalysis> {
686 let total_memory_usage = graph.nodes.values().map(|n| n.memory_usage).sum();
687
688 let mut memory_by_operation: HashMap<OperationType, u64> = HashMap::new();
689 for node in graph.nodes.values() {
690 *memory_by_operation.entry(node.operation_type.clone()).or_insert(0) +=
691 node.memory_usage;
692 }
693
694 let mut memory_hotspots: Vec<(String, u64)> =
695 graph.nodes.values().map(|n| (n.id.clone(), n.memory_usage)).collect();
696 memory_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
697 memory_hotspots.truncate(10); let peak_memory_usage = total_memory_usage; let fragmentation_ratio = 0.1; let optimization_suggestions = vec![
703 "Consider memory pooling for frequently allocated tensors".to_string(),
704 "Implement in-place operations where possible".to_string(),
705 "Use gradient checkpointing for memory-intensive layers".to_string(),
706 ];
707
708 Ok(MemoryAnalysis {
709 total_memory_usage,
710 peak_memory_usage,
711 memory_by_operation,
712 memory_hotspots,
713 fragmentation_ratio,
714 optimization_suggestions,
715 })
716 }
717
718 fn analyze_flop_usage(&self, graph: &ComputationGraph) -> Result<FlopAnalysis> {
719 let total_flops = graph.nodes.values().map(|n| n.flop_count).sum();
720
721 let mut flops_by_operation: HashMap<OperationType, u64> = HashMap::new();
722 for node in graph.nodes.values() {
723 *flops_by_operation.entry(node.operation_type.clone()).or_insert(0) += node.flop_count;
724 }
725
726 let mut compute_hotspots: Vec<(String, u64)> =
727 graph.nodes.values().map(|n| (n.id.clone(), n.flop_count)).collect();
728 compute_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
729 compute_hotspots.truncate(10); let total_memory = graph.nodes.values().map(|n| n.memory_usage).sum::<u64>();
732 let arithmetic_intensity =
733 if total_memory > 0 { total_flops as f64 / total_memory as f64 } else { 0.0 };
734
735 let complexity_analysis = ComplexityAnalysis {
736 time_complexity: "O(n)".to_string(), space_complexity: "O(n)".to_string(), parallelization_potential: 0.7, sequential_dependencies: graph.metadata.max_depth,
740 };
741
742 Ok(FlopAnalysis {
743 total_flops,
744 flops_by_operation,
745 compute_hotspots,
746 arithmetic_intensity,
747 complexity_analysis,
748 })
749 }
750
751 fn detect_optimization_opportunities(
752 &self,
753 graph: &ComputationGraph,
754 ) -> Result<Vec<OptimizationOpportunity>> {
755 let mut opportunities = Vec::new();
756
757 opportunities.extend(self.detect_fusion_opportunities(graph)?);
759
760 opportunities.extend(self.detect_redundancy_opportunities(graph)?);
762
763 opportunities.extend(self.detect_memory_optimizations(graph)?);
765
766 Ok(opportunities)
767 }
768
769 fn detect_fusion_opportunities(
770 &self,
771 graph: &ComputationGraph,
772 ) -> Result<Vec<OptimizationOpportunity>> {
773 let mut opportunities = Vec::new();
774
775 for node in graph.nodes.values() {
777 if let OperationType::Add = node.operation_type {
778 let empty_deps = vec![];
779 let dependencies = graph.edges.get(&node.id).unwrap_or(&empty_deps);
780 for dep in dependencies {
781 if let Some(dep_node) = graph.nodes.get(dep) {
782 if let OperationType::MatMul = dep_node.operation_type {
783 opportunities.push(OptimizationOpportunity {
784 optimization_type: OptimizationType::OperationFusion,
785 description:
786 "Fuse MatMul and Add operations into a single GEMM operation"
787 .to_string(),
788 affected_nodes: vec![dep.clone(), node.id.clone()],
789 estimated_improvement: EstimatedImprovement {
790 speedup_factor: 1.2,
791 memory_reduction: 1024 * 1024, energy_savings: 0.1,
793 },
794 implementation_difficulty: 2,
795 priority: OptimizationPriority::Medium,
796 });
797 }
798 }
799 }
800 }
801 }
802
803 Ok(opportunities)
804 }
805
806 fn detect_redundancy_opportunities(
807 &self,
808 _graph: &ComputationGraph,
809 ) -> Result<Vec<OptimizationOpportunity>> {
810 Ok(vec![])
812 }
813
814 fn detect_memory_optimizations(
815 &self,
816 graph: &ComputationGraph,
817 ) -> Result<Vec<OptimizationOpportunity>> {
818 let mut opportunities = Vec::new();
819
820 for node in graph.nodes.values() {
822 if node.memory_usage > self.config.large_memory_threshold {
823 opportunities.push(OptimizationOpportunity {
824 optimization_type: OptimizationType::MemoryLayoutOptimization,
825 description: format!(
826 "Optimize memory layout for large operation: {}",
827 node.name
828 ),
829 affected_nodes: vec![node.id.clone()],
830 estimated_improvement: EstimatedImprovement {
831 speedup_factor: 1.1,
832 memory_reduction: node.memory_usage / 4, energy_savings: 0.05,
834 },
835 implementation_difficulty: 3,
836 priority: OptimizationPriority::Medium,
837 });
838 }
839 }
840
841 Ok(opportunities)
842 }
843
844 fn analyze_bottlenecks(&self, graph: &ComputationGraph) -> Result<BottleneckAnalysis> {
845 let mut bottleneck_nodes = Vec::new();
846 let mut parallelizable_nodes = Vec::new();
847
848 for node in graph.nodes.values() {
849 if let Some(exec_time) = node.execution_time_us {
850 if exec_time > self.config.bottleneck_threshold_us {
851 bottleneck_nodes.push(node.id.clone());
852 }
853 }
854
855 match node.operation_type {
857 OperationType::MatMul | OperationType::Conv2D | OperationType::Add => {
858 parallelizable_nodes.push(node.id.clone());
859 },
860 _ => {},
861 }
862 }
863
864 let critical_path_nodes = self.find_critical_path(graph)?;
865 let critical_path_time_us = critical_path_nodes
866 .iter()
867 .filter_map(|id| graph.nodes.get(id))
868 .filter_map(|node| node.execution_time_us)
869 .sum();
870
871 let scheduling_suggestions = vec![
872 "Consider parallel execution of independent operations".to_string(),
873 "Use asynchronous execution for I/O operations".to_string(),
874 "Implement pipeline parallelism for sequential operations".to_string(),
875 ];
876
877 Ok(BottleneckAnalysis {
878 bottleneck_nodes,
879 critical_path_nodes,
880 critical_path_time_us,
881 parallelizable_nodes,
882 scheduling_suggestions,
883 })
884 }
885
886 fn analyze_dataflow(&self, graph: &ComputationGraph) -> Result<DataFlowAnalysis> {
887 let mut data_dependencies = HashMap::new();
888 let mut live_variables = HashMap::new();
889 let mut variable_lifetimes = HashMap::new();
890
891 for (node_id, dependencies) in &graph.edges {
893 data_dependencies.insert(node_id.clone(), dependencies.clone());
894 live_variables.insert(node_id.clone(), dependencies.iter().cloned().collect());
895
896 for dep in dependencies {
898 if !variable_lifetimes.contains_key(dep) {
899 variable_lifetimes.insert(
900 dep.clone(),
901 VariableLifetime {
902 birth_node: dep.clone(),
903 death_node: node_id.clone(),
904 usage_nodes: vec![node_id.clone()],
905 memory_footprint: graph
906 .nodes
907 .get(dep)
908 .map(|n| n.memory_usage)
909 .unwrap_or(0),
910 },
911 );
912 } else {
913 let lifetime = variable_lifetimes.get_mut(dep).unwrap();
914 lifetime.death_node = node_id.clone();
915 lifetime.usage_nodes.push(node_id.clone());
916 }
917 }
918 }
919
920 let memory_reuse_opportunities = vec![MemoryReuseOpportunity {
921 reusable_variables: vec!["var1".to_string(), "var2".to_string()],
922 memory_savings: 1024 * 1024, complexity: 2,
924 }];
925
926 Ok(DataFlowAnalysis {
927 data_dependencies,
928 live_variables,
929 variable_lifetimes,
930 memory_reuse_opportunities,
931 })
932 }
933
934 fn find_critical_path(&self, graph: &ComputationGraph) -> Result<Vec<String>> {
935 let mut path = Vec::new();
937 let mut current_depth = graph.metadata.max_depth;
938
939 while current_depth > 0 {
940 for node in graph.nodes.values() {
942 if node.depth == current_depth {
943 path.push(node.id.clone());
944 current_depth -= 1;
945 break;
946 }
947 }
948 current_depth = current_depth.saturating_sub(1);
949 }
950
951 path.reverse();
952 Ok(path)
953 }
954
955 fn calculate_statistics(&self, graph: &ComputationGraph) -> Result<GraphStatistics> {
956 let mut nodes_by_type: HashMap<OperationType, usize> = HashMap::new();
957 for node in graph.nodes.values() {
958 *nodes_by_type.entry(node.operation_type.clone()).or_insert(0) += 1;
959 }
960
961 let total_fan_in: usize = graph.edges.values().map(|deps| deps.len()).sum();
962 let total_fan_out = total_fan_in; let average_fan_in = total_fan_in as f64 / graph.nodes.len() as f64;
964 let average_fan_out = total_fan_out as f64 / graph.nodes.len() as f64;
965
966 Ok(GraphStatistics {
967 nodes_by_type,
968 average_fan_in,
969 average_fan_out,
970 diameter: graph.metadata.max_depth,
971 clustering_coefficient: 0.0, strongly_connected_components: graph.nodes.len(), })
974 }
975
976 fn generate_recommendations(&self, analysis: &GraphAnalysisResult) -> Result<Vec<String>> {
977 let mut recommendations = Vec::new();
978
979 if let Some(ref memory_analysis) = analysis.memory_analysis {
981 if memory_analysis.total_memory_usage > 1024 * 1024 * 1024 {
982 recommendations.push(
984 "Consider using gradient checkpointing to reduce memory usage".to_string(),
985 );
986 }
987 if memory_analysis.fragmentation_ratio > 0.2 {
988 recommendations
989 .push("Implement memory pooling to reduce fragmentation".to_string());
990 }
991 }
992
993 if let Some(ref flop_analysis) = analysis.flop_analysis {
995 if flop_analysis.arithmetic_intensity < 1.0 {
996 recommendations
997 .push("Consider kernel fusion to improve arithmetic intensity".to_string());
998 }
999 if flop_analysis.complexity_analysis.parallelization_potential > 0.5 {
1000 recommendations.push(
1001 "Explore parallelization opportunities for compute-intensive operations"
1002 .to_string(),
1003 );
1004 }
1005 }
1006
1007 if analysis.optimization_opportunities.len() > 3 {
1009 recommendations.push(
1010 "Multiple optimization opportunities detected - prioritize by estimated impact"
1011 .to_string(),
1012 );
1013 }
1014
1015 if let Some(ref bottleneck_analysis) = analysis.bottleneck_analysis {
1017 if !bottleneck_analysis.bottleneck_nodes.is_empty() {
1018 recommendations.push(
1019 "Address bottleneck operations through optimization or parallelization"
1020 .to_string(),
1021 );
1022 }
1023 }
1024
1025 Ok(recommendations)
1026 }
1027
1028 fn get_node_color(&self, op_type: &OperationType) -> &'static str {
1029 match op_type {
1030 OperationType::MatMul | OperationType::Dot => "lightblue",
1031 OperationType::Add
1032 | OperationType::Subtract
1033 | OperationType::Multiply
1034 | OperationType::Divide => "lightgreen",
1035 OperationType::ReLU
1036 | OperationType::Sigmoid
1037 | OperationType::Tanh
1038 | OperationType::GELU => "orange",
1039 OperationType::LayerNorm | OperationType::BatchNorm | OperationType::RMSNorm => {
1040 "yellow"
1041 },
1042 OperationType::Conv1D | OperationType::Conv2D | OperationType::Conv3D => "lightcoral",
1043 OperationType::Attention | OperationType::MultiHeadAttention => "purple",
1044 OperationType::Embedding | OperationType::PositionalEmbedding => "pink",
1045 _ => "lightgray",
1046 }
1047 }
1048}
1049
1050impl fmt::Display for OperationType {
1051 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1052 match self {
1053 OperationType::Custom(name) => write!(f, "Custom({})", name),
1054 _ => write!(f, "{:?}", self),
1055 }
1056 }
1057}
1058
1059impl Default for ComputationGraphAnalyzer {
1060 fn default() -> Self {
1061 Self::new(GraphAnalysisConfig::default())
1062 }
1063}
1064
1065#[cfg(test)]
1066mod tests {
1067 use super::*;
1068
    // End-to-end: build a linear 5-node pipeline
    // (input -> linear1 -> relu1 -> linear2 -> output) and analyze it.
    #[test]
    fn test_computation_graph_creation() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        let operations = vec![
            (
                "input".to_string(),
                OperationType::Custom("Input".to_string()),
                vec![],
            ),
            (
                "linear1".to_string(),
                OperationType::MatMul,
                vec!["input".to_string()],
            ),
            (
                "relu1".to_string(),
                OperationType::ReLU,
                vec!["linear1".to_string()],
            ),
            (
                "linear2".to_string(),
                OperationType::MatMul,
                vec!["relu1".to_string()],
            ),
            (
                "output".to_string(),
                OperationType::Custom("Output".to_string()),
                vec!["linear2".to_string()],
            ),
        ];

        let graph_id = analyzer.create_graph("test_model".to_string(), operations).unwrap();
        let analysis = analyzer.analyze_graph(graph_id).unwrap();

        // Four distinct op types: MatMul, ReLU, Custom("Input"), Custom("Output").
        assert_eq!(analysis.statistics.nodes_by_type.len(), 4); assert!(analysis.critical_path.len() > 0);
    }
1107
    // A MatMul feeding an Add should trigger the GEMM fusion detector.
    #[test]
    fn test_optimization_detection() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        let operations = vec![
            (
                "input".to_string(),
                OperationType::Custom("Input".to_string()),
                vec![],
            ),
            (
                "matmul".to_string(),
                OperationType::MatMul,
                vec!["input".to_string()],
            ),
            (
                "add".to_string(),
                OperationType::Add,
                vec!["matmul".to_string()],
            ),
        ];

        let graph_id = analyzer.create_graph("fusion_test".to_string(), operations).unwrap();
        let analysis = analyzer.analyze_graph(graph_id).unwrap();

        // At least one opportunity must be the MatMul+Add fusion.
        assert!(analysis
            .optimization_opportunities
            .iter()
            .any(|op| op.optimization_type == OptimizationType::OperationFusion));
    }
1138
    // DOT export should name the digraph and label nodes with their op types.
    #[test]
    fn test_dot_export() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        let operations = vec![
            ("a".to_string(), OperationType::MatMul, vec![]),
            ("b".to_string(), OperationType::ReLU, vec!["a".to_string()]),
        ];

        let graph_id = analyzer.create_graph("simple".to_string(), operations).unwrap();
        let dot = analyzer.export_to_dot(graph_id).unwrap();

        assert!(dot.contains("digraph"));
        assert!(dot.contains("MatMul"));
        assert!(dot.contains("ReLU"));
    }
1155}