1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use uuid::Uuid;
12
/// A directed acyclic computation graph: nodes are operations, edges are
/// dependency lists (node id -> ids of the nodes it consumes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph {
    /// Unique identifier for this graph.
    pub id: Uuid,
    /// All nodes, keyed by node id.
    pub nodes: HashMap<String, GraphNode>,
    /// Dependency edges: node id -> ids of its input (dependency) nodes.
    pub edges: HashMap<String, Vec<String>>,
    /// Nodes with no dependencies (graph inputs).
    pub root_nodes: HashSet<String>,
    /// Nodes no other node depends on (graph outputs).
    pub leaf_nodes: HashSet<String>,
    /// Aggregate figures captured at construction time.
    pub metadata: GraphMetadata,
}
29
/// Aggregate statistics captured when a graph is constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Human-readable graph name.
    pub name: String,
    /// Total number of nodes.
    pub node_count: usize,
    /// Total number of dependency edges.
    pub edge_count: usize,
    /// Deepest node depth (roots are depth 0).
    pub max_depth: usize,
    /// Sum of per-node memory estimates, in bytes.
    pub estimated_memory_usage: u64,
    /// Sum of per-node FLOP estimates.
    pub estimated_flops: u64,
    /// Construction timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
48
/// A single operation (node) in a [`ComputationGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique id within the graph; also the key in `nodes` and `edges`.
    pub id: String,
    /// Human-readable name (defaults to the id when built via `create_graph`).
    pub name: String,
    /// What this node computes.
    pub operation_type: OperationType,
    /// Shape of each input tensor (may be empty when unknown).
    pub input_shapes: Vec<Vec<usize>>,
    /// Shape of each output tensor (may be empty when unknown).
    pub output_shapes: Vec<Vec<usize>>,
    /// Estimated floating-point operations performed by this node.
    pub flop_count: u64,
    /// Estimated memory footprint, in bytes.
    pub memory_usage: u64,
    /// Measured execution time in microseconds, if profiled.
    pub execution_time_us: Option<u64>,
    /// Number of learnable parameters, if applicable to the operation.
    pub parameter_count: Option<u64>,
    /// Position in topological order, set by the depth/topo computation.
    pub topo_order: Option<usize>,
    /// Level in the dependency ordering (roots are depth 0).
    pub depth: usize,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
77
/// The kind of computation a [`GraphNode`] performs.
///
/// Variants are grouped by family; `Custom` covers anything not listed.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OperationType {
    // Element-wise arithmetic and linear algebra.
    Add,
    Subtract,
    Multiply,
    Divide,
    MatMul,
    Dot,

    // Activation functions.
    ReLU,
    Sigmoid,
    Tanh,
    GELU,
    Softmax,

    // Normalization layers.
    LayerNorm,
    BatchNorm,
    RMSNorm,

    // Convolutions.
    Conv1D,
    Conv2D,
    Conv3D,
    ConvTranspose,

    // Pooling.
    MaxPool,
    AvgPool,
    AdaptivePool,

    // Tensor shape / data movement.
    Reshape,
    Transpose,
    Concat,
    Split,
    Slice,
    Gather,
    Scatter,

    // Reductions.
    Sum,
    Mean,
    Max,
    Min,

    // Attention mechanisms.
    Attention,
    MultiHeadAttention,
    SelfAttention,
    CrossAttention,

    // Embedding lookups.
    Embedding,
    PositionalEmbedding,

    // Loss functions.
    CrossEntropyLoss,
    MSELoss,
    L1Loss,

    // Control flow.
    If,
    While,
    Loop,

    /// Escape hatch for operations not modeled above; carries a label.
    Custom(String),
}
150
/// Feature switches and thresholds controlling what `analyze_graph` computes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisConfig {
    /// Run memory-usage analysis.
    pub enable_memory_analysis: bool,
    /// Run FLOP / compute analysis.
    pub enable_flop_analysis: bool,
    /// Detect optimization opportunities (fusion, redundancy, layout).
    pub enable_optimization_analysis: bool,
    /// Detect bottleneck nodes and compute the critical path time.
    pub enable_bottleneck_detection: bool,
    /// Run dataflow / variable-lifetime analysis.
    pub enable_dataflow_analysis: bool,
    /// Nodes whose measured execution time exceeds this (µs) are bottlenecks.
    pub bottleneck_threshold_us: u64,
    /// Nodes whose footprint exceeds this many bytes are considered "large".
    pub large_memory_threshold: u64,
}
169
170impl Default for GraphAnalysisConfig {
171 fn default() -> Self {
172 Self {
173 enable_memory_analysis: true,
174 enable_flop_analysis: true,
175 enable_optimization_analysis: true,
176 enable_bottleneck_detection: true,
177 enable_dataflow_analysis: true,
178 bottleneck_threshold_us: 1000, large_memory_threshold: 1024 * 1024 * 100, }
181 }
182}
183
/// Stateful analyzer: owns registered graphs and caches analysis results.
#[derive(Debug)]
pub struct ComputationGraphAnalyzer {
    /// Thresholds and feature flags used by every analysis pass.
    config: GraphAnalysisConfig,
    /// Registered graphs, keyed by graph id.
    graphs: HashMap<Uuid, ComputationGraph>,
    /// Most recent analysis result per graph id.
    analysis_results: HashMap<Uuid, GraphAnalysisResult>,
}
191
/// Everything `analyze_graph` produces for one graph.
///
/// Optional sections are `None` when the corresponding flag in
/// [`GraphAnalysisConfig`] is disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisResult {
    /// Id of the analyzed graph.
    pub graph_id: Uuid,
    /// Memory-usage breakdown, if enabled.
    pub memory_analysis: Option<MemoryAnalysis>,
    /// FLOP breakdown, if enabled.
    pub flop_analysis: Option<FlopAnalysis>,
    /// Detected optimization candidates, if enabled (empty otherwise).
    pub optimization_opportunities: Vec<OptimizationOpportunity>,
    /// Bottleneck / critical-path timing, if enabled.
    pub bottleneck_analysis: Option<BottleneckAnalysis>,
    /// Dataflow / lifetime information, if enabled.
    pub dataflow_analysis: Option<DataFlowAnalysis>,
    /// Node ids forming the critical path, ordered root to leaf.
    pub critical_path: Vec<String>,
    /// Structural statistics (always computed).
    pub statistics: GraphStatistics,
    /// Human-readable suggestions derived from the sections above.
    pub recommendations: Vec<String>,
}
214
/// Memory-usage breakdown for one graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAnalysis {
    /// Sum of all per-node memory estimates, in bytes.
    pub total_memory_usage: u64,
    /// Peak usage estimate (currently equal to the total — see analyzer).
    pub peak_memory_usage: u64,
    /// Memory aggregated per operation type.
    pub memory_by_operation: HashMap<OperationType, u64>,
    /// Up to ten (node id, bytes) pairs, largest consumers first.
    pub memory_hotspots: Vec<(String, u64)>,
    /// Estimated fragmentation (0.0–1.0); currently a fixed placeholder.
    pub fragmentation_ratio: f64,
    /// Generic memory-optimization suggestions.
    pub optimization_suggestions: Vec<String>,
}
231
/// Compute (FLOP) breakdown for one graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlopAnalysis {
    /// Sum of all per-node FLOP estimates.
    pub total_flops: u64,
    /// FLOPs aggregated per operation type.
    pub flops_by_operation: HashMap<OperationType, u64>,
    /// Up to ten (node id, FLOPs) pairs, heaviest first.
    pub compute_hotspots: Vec<(String, u64)>,
    /// FLOPs per byte of estimated memory; low values suggest memory-bound.
    pub arithmetic_intensity: f64,
    /// Coarse complexity characterization.
    pub complexity_analysis: ComplexityAnalysis,
}
246
/// Coarse complexity characterization of a graph.
///
/// NOTE(review): the analyzer currently fills most of these fields with
/// fixed placeholder values rather than deriving them from the graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityAnalysis {
    /// Big-O time estimate as a display string.
    pub time_complexity: String,
    /// Big-O space estimate as a display string.
    pub space_complexity: String,
    /// Fraction of work believed parallelizable (0.0–1.0).
    pub parallelization_potential: f64,
    /// Length of the longest sequential dependency chain (graph max depth).
    pub sequential_dependencies: usize,
}
259
/// One detected candidate for optimizing the graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationOpportunity {
    /// Category of the optimization.
    pub optimization_type: OptimizationType,
    /// Human-readable explanation of what to do.
    pub description: String,
    /// Ids of the nodes the optimization would touch.
    pub affected_nodes: Vec<String>,
    /// Rough expected benefit.
    pub estimated_improvement: EstimatedImprovement,
    /// Effort estimate on an informal 1 (easy) – 5 (hard) scale.
    pub implementation_difficulty: u8,
    /// How urgently this should be addressed.
    pub priority: OptimizationPriority,
}
276
/// Categories of graph optimizations the analyzer can suggest.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationType {
    /// Merge adjacent operations into one kernel (e.g. MatMul + Add → GEMM).
    OperationFusion,
    /// Remove duplicated / dead computation.
    RedundancyElimination,
    /// Change tensor layout for better locality.
    MemoryLayoutOptimization,
    /// Replace an algorithm with an asymptotically better one.
    AlgorithmicOptimization,
    /// Run independent work concurrently.
    Parallelization,
    /// Improve access patterns (caching, prefetch, batching).
    DataAccessOptimization,
    /// Use lower precision where accuracy permits.
    PrecisionOptimization,
    /// Cache and reuse previously computed results.
    Memoization,
    /// Simplify branches / loops.
    ControlFlowOptimization,
}
299
/// Urgency ranking for an [`OptimizationOpportunity`].
///
/// Ord is derived, so variant order matters: `Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum OptimizationPriority {
    Low,
    Medium,
    High,
    Critical,
}
308
/// Rough expected benefit of applying an optimization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EstimatedImprovement {
    /// Multiplicative speedup (1.0 = no change).
    pub speedup_factor: f64,
    /// Bytes of memory expected to be saved.
    pub memory_reduction: u64,
    /// Fractional energy saving (0.0–1.0).
    pub energy_savings: f64,
}
319
/// Bottleneck and critical-path results for one graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    /// Nodes whose measured time exceeds the configured threshold.
    pub bottleneck_nodes: Vec<String>,
    /// Node ids on the critical path, ordered root to leaf.
    pub critical_path_nodes: Vec<String>,
    /// Sum of measured execution times (µs) along the critical path;
    /// nodes without measurements contribute nothing.
    pub critical_path_time_us: u64,
    /// Nodes whose operation type is considered parallelizable.
    pub parallelizable_nodes: Vec<String>,
    /// Generic scheduling suggestions.
    pub scheduling_suggestions: Vec<String>,
}
334
/// Dataflow facts derived from the graph's edge lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFlowAnalysis {
    /// Node id -> ids of the nodes it reads from (mirror of the edge map).
    pub data_dependencies: HashMap<String, Vec<String>>,
    /// Node id -> set of values that must be live when it executes.
    pub live_variables: HashMap<String, HashSet<String>>,
    /// Producer node id -> lifetime of the value it produces.
    pub variable_lifetimes: HashMap<String, VariableLifetime>,
    /// Detected chances to reuse buffers between variables.
    pub memory_reuse_opportunities: Vec<MemoryReuseOpportunity>,
}
347
/// Span over which a produced value must stay alive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableLifetime {
    /// Node that produces the value.
    pub birth_node: String,
    /// Last node observed consuming the value.
    pub death_node: String,
    /// All nodes that consume the value.
    pub usage_nodes: Vec<String>,
    /// Bytes held alive over the lifetime (producer's memory estimate).
    pub memory_footprint: u64,
}
360
/// A chance to share one buffer between variables with disjoint lifetimes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReuseOpportunity {
    /// Variables whose storage could be shared.
    pub reusable_variables: Vec<String>,
    /// Bytes expected to be saved.
    pub memory_savings: u64,
    /// Effort estimate on an informal 1 (easy) – 5 (hard) scale.
    pub complexity: u8,
}
371
/// Structural statistics about a graph's shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStatistics {
    /// Node count per operation type.
    pub nodes_by_type: HashMap<OperationType, usize>,
    /// Mean number of inputs per node.
    pub average_fan_in: f64,
    /// Mean number of consumers per node (equals fan-in for a closed graph).
    pub average_fan_out: f64,
    /// Longest dependency chain (graph max depth).
    pub diameter: usize,
    /// Placeholder — clustering is not computed yet (always 0.0).
    pub clustering_coefficient: f64,
    /// In a DAG every node is its own SCC, so this equals the node count.
    pub strongly_connected_components: usize,
}
388
389impl ComputationGraphAnalyzer {
390 pub fn new(config: GraphAnalysisConfig) -> Self {
392 Self {
393 config,
394 graphs: HashMap::new(),
395 analysis_results: HashMap::new(),
396 }
397 }
398
399 pub fn add_graph(&mut self, graph: ComputationGraph) -> Result<()> {
401 let graph_id = graph.id;
402 self.graphs.insert(graph_id, graph);
403 Ok(())
404 }
405
406 pub fn create_graph(
408 &mut self,
409 name: String,
410 operations: Vec<(String, OperationType, Vec<String>)>, ) -> Result<Uuid> {
412 let graph_id = Uuid::new_v4();
413 let mut nodes = HashMap::new();
414 let mut edges = HashMap::new();
415 let mut root_nodes = HashSet::new();
416 let mut leaf_nodes = HashSet::new();
417
418 for (node_id, op_type, dependencies) in &operations {
420 let node = GraphNode {
421 id: node_id.clone(),
422 name: node_id.clone(),
423 operation_type: op_type.clone(),
424 input_shapes: vec![],
425 output_shapes: vec![],
426 flop_count: self.estimate_flops(op_type, &[]),
427 memory_usage: self.estimate_memory(op_type, &[]),
428 execution_time_us: None,
429 parameter_count: self.estimate_parameters(op_type),
430 topo_order: None,
431 depth: 0,
432 metadata: HashMap::new(),
433 };
434 nodes.insert(node_id.clone(), node);
435
436 if dependencies.is_empty() {
438 root_nodes.insert(node_id.clone());
439 }
440 edges.insert(node_id.clone(), dependencies.clone());
441 }
442
443 let all_dependencies: HashSet<String> = edges.values().flatten().cloned().collect();
445 for node_id in nodes.keys() {
446 if !all_dependencies.contains(node_id) {
447 leaf_nodes.insert(node_id.clone());
448 }
449 }
450
451 self.calculate_depth_and_topo_order(&mut nodes, &edges)?;
453
454 let metadata = GraphMetadata {
455 name,
456 node_count: nodes.len(),
457 edge_count: edges.values().map(|deps| deps.len()).sum(),
458 max_depth: nodes.values().map(|n| n.depth).max().unwrap_or(0),
459 estimated_memory_usage: nodes.values().map(|n| n.memory_usage).sum(),
460 estimated_flops: nodes.values().map(|n| n.flop_count).sum(),
461 created_at: chrono::Utc::now(),
462 };
463
464 let graph = ComputationGraph {
465 id: graph_id,
466 nodes,
467 edges,
468 root_nodes,
469 leaf_nodes,
470 metadata,
471 };
472
473 self.graphs.insert(graph_id, graph);
474 Ok(graph_id)
475 }
476
    /// Run every enabled analysis pass over the graph, cache the combined
    /// result, and return a copy of it.
    ///
    /// Structural statistics and the critical path are always computed;
    /// the other sections stay `None`/empty when disabled in the config.
    ///
    /// # Errors
    /// Fails if `graph_id` is unknown or any analysis pass fails.
    pub fn analyze_graph(&mut self, graph_id: Uuid) -> Result<GraphAnalysisResult> {
        let graph = self
            .graphs
            .get(&graph_id)
            .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;

        let mut result = GraphAnalysisResult {
            graph_id,
            memory_analysis: None,
            flop_analysis: None,
            optimization_opportunities: Vec::new(),
            bottleneck_analysis: None,
            dataflow_analysis: None,
            critical_path: Vec::new(),
            statistics: self.calculate_statistics(graph)?,
            recommendations: Vec::new(),
        };

        if self.config.enable_memory_analysis {
            result.memory_analysis = Some(self.analyze_memory_usage(graph)?);
        }

        if self.config.enable_flop_analysis {
            result.flop_analysis = Some(self.analyze_flop_usage(graph)?);
        }

        if self.config.enable_optimization_analysis {
            result.optimization_opportunities = self.detect_optimization_opportunities(graph)?;
        }

        if self.config.enable_bottleneck_detection {
            result.bottleneck_analysis = Some(self.analyze_bottlenecks(graph)?);
        }

        if self.config.enable_dataflow_analysis {
            result.dataflow_analysis = Some(self.analyze_dataflow(graph)?);
        }

        result.critical_path = self.find_critical_path(graph)?;
        // Recommendations are derived from the sections filled in above.
        result.recommendations = self.generate_recommendations(&result)?;

        // The shared borrow of `self.graphs` ends before this mutable use.
        self.analysis_results.insert(graph_id, result.clone());
        Ok(result)
    }
523
524 pub fn get_analysis_result(&self, graph_id: Uuid) -> Option<&GraphAnalysisResult> {
526 self.analysis_results.get(&graph_id)
527 }
528
529 pub fn export_to_dot(&self, graph_id: Uuid) -> Result<String> {
531 let graph = self
532 .graphs
533 .get(&graph_id)
534 .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;
535
536 let mut dot = String::new();
537 dot.push_str(&format!("digraph \"{}\" {{\n", graph.metadata.name));
538 dot.push_str(" rankdir=TB;\n");
539 dot.push_str(" node [shape=box, style=filled];\n\n");
540
541 for node in graph.nodes.values() {
543 let color = self.get_node_color(&node.operation_type);
544 let label = format!(
545 "{}\\n{}\\n{:.1} GFLOP\\n{:.1} MB",
546 node.name,
547 format!("{:?}", node.operation_type),
548 node.flop_count as f64 / 1e9,
549 node.memory_usage as f64 / (1024.0 * 1024.0)
550 );
551
552 dot.push_str(&format!(
553 " \"{}\" [label=\"{}\", fillcolor=\"{}\"];\n",
554 node.id, label, color
555 ));
556 }
557
558 dot.push('\n');
559
560 for (node_id, dependencies) in &graph.edges {
562 for dep in dependencies {
563 dot.push_str(&format!(" \"{}\" -> \"{}\";\n", dep, node_id));
564 }
565 }
566
567 dot.push_str("}\n");
568 Ok(dot)
569 }
570
571 fn calculate_depth_and_topo_order(
574 &self,
575 nodes: &mut HashMap<String, GraphNode>,
576 edges: &HashMap<String, Vec<String>>,
577 ) -> Result<()> {
578 let mut in_degree: HashMap<String, usize> = HashMap::new();
580 let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
581
582 for node_id in nodes.keys() {
584 in_degree.insert(node_id.clone(), 0);
585 adj_list.insert(node_id.clone(), Vec::new());
586 }
587
588 for (node_id, dependencies) in edges {
589 in_degree.insert(node_id.clone(), dependencies.len());
590 for dep in dependencies {
591 if let Some(adj) = adj_list.get_mut(dep) {
592 adj.push(node_id.clone());
593 }
594 }
595 }
596
597 let mut queue = VecDeque::new();
599 let mut topo_order = 0;
600
601 for (node_id, °ree) in &in_degree {
603 if degree == 0 {
604 queue.push_back((node_id.clone(), 0)); }
606 }
607
608 while let Some((node_id, depth)) = queue.pop_front() {
609 if let Some(node) = nodes.get_mut(&node_id) {
611 node.depth = depth;
612 node.topo_order = Some(topo_order);
613 topo_order += 1;
614 }
615
616 if let Some(neighbors) = adj_list.get(&node_id) {
618 for neighbor in neighbors {
619 if let Some(degree) = in_degree.get_mut(neighbor) {
620 *degree -= 1;
621 if *degree == 0 {
622 queue.push_back((neighbor.clone(), depth + 1));
623 }
624 }
625 }
626 }
627 }
628
629 Ok(())
630 }
631
632 fn estimate_flops(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
633 match op_type {
635 OperationType::MatMul => {
636 if shapes.len() >= 2 {
637 let a_shape = &shapes[0];
638 let b_shape = &shapes[1];
639 if a_shape.len() >= 2 && b_shape.len() >= 2 {
640 let m = a_shape[a_shape.len() - 2];
641 let k = a_shape[a_shape.len() - 1];
642 let n = b_shape[b_shape.len() - 1];
643 return (2 * m * k * n) as u64;
644 }
645 }
646 1000000 },
648 OperationType::Add | OperationType::Subtract | OperationType::Multiply => {
649 shapes.first().map(|s| s.iter().product::<usize>() as u64).unwrap_or(1000)
650 },
651 OperationType::ReLU | OperationType::Sigmoid | OperationType::Tanh => {
652 shapes.first().map(|s| s.iter().product::<usize>() as u64).unwrap_or(1000)
653 },
654 OperationType::LayerNorm | OperationType::BatchNorm => {
655 shapes.first().map(|s| (s.iter().product::<usize>() * 5) as u64).unwrap_or(5000)
656 },
657 _ => 1000, }
659 }
660
661 fn estimate_memory(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
662 let element_size = 4u64;
664 match op_type {
665 OperationType::MatMul => {
666 shapes
667 .iter()
668 .map(|s| s.iter().product::<usize>() as u64 * element_size)
669 .sum::<u64>()
670 .max(1024) },
672 _ => shapes
673 .first()
674 .map(|s| s.iter().product::<usize>() as u64 * element_size)
675 .unwrap_or(1024),
676 }
677 }
678
679 fn estimate_parameters(&self, op_type: &OperationType) -> Option<u64> {
680 match op_type {
681 OperationType::MatMul => Some(1000000), OperationType::Conv2D => Some(500000),
683 OperationType::Embedding => Some(2000000),
684 OperationType::LayerNorm => Some(1000),
685 _ => None,
686 }
687 }
688
689 fn analyze_memory_usage(&self, graph: &ComputationGraph) -> Result<MemoryAnalysis> {
690 let total_memory_usage = graph.nodes.values().map(|n| n.memory_usage).sum();
691
692 let mut memory_by_operation: HashMap<OperationType, u64> = HashMap::new();
693 for node in graph.nodes.values() {
694 *memory_by_operation.entry(node.operation_type.clone()).or_insert(0) +=
695 node.memory_usage;
696 }
697
698 let mut memory_hotspots: Vec<(String, u64)> =
699 graph.nodes.values().map(|n| (n.id.clone(), n.memory_usage)).collect();
700 memory_hotspots.sort_by_key(|item| std::cmp::Reverse(item.1));
701 memory_hotspots.truncate(10); let peak_memory_usage = total_memory_usage; let fragmentation_ratio = 0.1; let optimization_suggestions = vec![
707 "Consider memory pooling for frequently allocated tensors".to_string(),
708 "Implement in-place operations where possible".to_string(),
709 "Use gradient checkpointing for memory-intensive layers".to_string(),
710 ];
711
712 Ok(MemoryAnalysis {
713 total_memory_usage,
714 peak_memory_usage,
715 memory_by_operation,
716 memory_hotspots,
717 fragmentation_ratio,
718 optimization_suggestions,
719 })
720 }
721
722 fn analyze_flop_usage(&self, graph: &ComputationGraph) -> Result<FlopAnalysis> {
723 let total_flops = graph.nodes.values().map(|n| n.flop_count).sum();
724
725 let mut flops_by_operation: HashMap<OperationType, u64> = HashMap::new();
726 for node in graph.nodes.values() {
727 *flops_by_operation.entry(node.operation_type.clone()).or_insert(0) += node.flop_count;
728 }
729
730 let mut compute_hotspots: Vec<(String, u64)> =
731 graph.nodes.values().map(|n| (n.id.clone(), n.flop_count)).collect();
732 compute_hotspots.sort_by_key(|item| std::cmp::Reverse(item.1));
733 compute_hotspots.truncate(10); let total_memory = graph.nodes.values().map(|n| n.memory_usage).sum::<u64>();
736 let arithmetic_intensity =
737 if total_memory > 0 { total_flops as f64 / total_memory as f64 } else { 0.0 };
738
739 let complexity_analysis = ComplexityAnalysis {
740 time_complexity: "O(n)".to_string(), space_complexity: "O(n)".to_string(), parallelization_potential: 0.7, sequential_dependencies: graph.metadata.max_depth,
744 };
745
746 Ok(FlopAnalysis {
747 total_flops,
748 flops_by_operation,
749 compute_hotspots,
750 arithmetic_intensity,
751 complexity_analysis,
752 })
753 }
754
755 fn detect_optimization_opportunities(
756 &self,
757 graph: &ComputationGraph,
758 ) -> Result<Vec<OptimizationOpportunity>> {
759 let mut opportunities = Vec::new();
760
761 opportunities.extend(self.detect_fusion_opportunities(graph)?);
763
764 opportunities.extend(self.detect_redundancy_opportunities(graph)?);
766
767 opportunities.extend(self.detect_memory_optimizations(graph)?);
769
770 Ok(opportunities)
771 }
772
773 fn detect_fusion_opportunities(
774 &self,
775 graph: &ComputationGraph,
776 ) -> Result<Vec<OptimizationOpportunity>> {
777 let mut opportunities = Vec::new();
778
779 for node in graph.nodes.values() {
781 if let OperationType::Add = node.operation_type {
782 let empty_deps = vec![];
783 let dependencies = graph.edges.get(&node.id).unwrap_or(&empty_deps);
784 for dep in dependencies {
785 if let Some(dep_node) = graph.nodes.get(dep) {
786 if let OperationType::MatMul = dep_node.operation_type {
787 opportunities.push(OptimizationOpportunity {
788 optimization_type: OptimizationType::OperationFusion,
789 description:
790 "Fuse MatMul and Add operations into a single GEMM operation"
791 .to_string(),
792 affected_nodes: vec![dep.clone(), node.id.clone()],
793 estimated_improvement: EstimatedImprovement {
794 speedup_factor: 1.2,
795 memory_reduction: 1024 * 1024, energy_savings: 0.1,
797 },
798 implementation_difficulty: 2,
799 priority: OptimizationPriority::Medium,
800 });
801 }
802 }
803 }
804 }
805 }
806
807 Ok(opportunities)
808 }
809
810 fn detect_redundancy_opportunities(
811 &self,
812 _graph: &ComputationGraph,
813 ) -> Result<Vec<OptimizationOpportunity>> {
814 Ok(vec![])
816 }
817
818 fn detect_memory_optimizations(
819 &self,
820 graph: &ComputationGraph,
821 ) -> Result<Vec<OptimizationOpportunity>> {
822 let mut opportunities = Vec::new();
823
824 for node in graph.nodes.values() {
826 if node.memory_usage > self.config.large_memory_threshold {
827 opportunities.push(OptimizationOpportunity {
828 optimization_type: OptimizationType::MemoryLayoutOptimization,
829 description: format!(
830 "Optimize memory layout for large operation: {}",
831 node.name
832 ),
833 affected_nodes: vec![node.id.clone()],
834 estimated_improvement: EstimatedImprovement {
835 speedup_factor: 1.1,
836 memory_reduction: node.memory_usage / 4, energy_savings: 0.05,
838 },
839 implementation_difficulty: 3,
840 priority: OptimizationPriority::Medium,
841 });
842 }
843 }
844
845 Ok(opportunities)
846 }
847
848 fn analyze_bottlenecks(&self, graph: &ComputationGraph) -> Result<BottleneckAnalysis> {
849 let mut bottleneck_nodes = Vec::new();
850 let mut parallelizable_nodes = Vec::new();
851
852 for node in graph.nodes.values() {
853 if let Some(exec_time) = node.execution_time_us {
854 if exec_time > self.config.bottleneck_threshold_us {
855 bottleneck_nodes.push(node.id.clone());
856 }
857 }
858
859 match node.operation_type {
861 OperationType::MatMul | OperationType::Conv2D | OperationType::Add => {
862 parallelizable_nodes.push(node.id.clone());
863 },
864 _ => {},
865 }
866 }
867
868 let critical_path_nodes = self.find_critical_path(graph)?;
869 let critical_path_time_us = critical_path_nodes
870 .iter()
871 .filter_map(|id| graph.nodes.get(id))
872 .filter_map(|node| node.execution_time_us)
873 .sum();
874
875 let scheduling_suggestions = vec![
876 "Consider parallel execution of independent operations".to_string(),
877 "Use asynchronous execution for I/O operations".to_string(),
878 "Implement pipeline parallelism for sequential operations".to_string(),
879 ];
880
881 Ok(BottleneckAnalysis {
882 bottleneck_nodes,
883 critical_path_nodes,
884 critical_path_time_us,
885 parallelizable_nodes,
886 scheduling_suggestions,
887 })
888 }
889
890 fn analyze_dataflow(&self, graph: &ComputationGraph) -> Result<DataFlowAnalysis> {
891 let mut data_dependencies = HashMap::new();
892 let mut live_variables = HashMap::new();
893 let mut variable_lifetimes = HashMap::new();
894
895 for (node_id, dependencies) in &graph.edges {
897 data_dependencies.insert(node_id.clone(), dependencies.clone());
898 live_variables.insert(node_id.clone(), dependencies.iter().cloned().collect());
899
900 for dep in dependencies {
902 if !variable_lifetimes.contains_key(dep) {
903 variable_lifetimes.insert(
904 dep.clone(),
905 VariableLifetime {
906 birth_node: dep.clone(),
907 death_node: node_id.clone(),
908 usage_nodes: vec![node_id.clone()],
909 memory_footprint: graph
910 .nodes
911 .get(dep)
912 .map(|n| n.memory_usage)
913 .unwrap_or(0),
914 },
915 );
916 } else {
917 let lifetime = variable_lifetimes
918 .get_mut(dep)
919 .expect("variable lifetime should exist for previously seen dependency");
920 lifetime.death_node = node_id.clone();
921 lifetime.usage_nodes.push(node_id.clone());
922 }
923 }
924 }
925
926 let memory_reuse_opportunities = vec![MemoryReuseOpportunity {
927 reusable_variables: vec!["var1".to_string(), "var2".to_string()],
928 memory_savings: 1024 * 1024, complexity: 2,
930 }];
931
932 Ok(DataFlowAnalysis {
933 data_dependencies,
934 live_variables,
935 variable_lifetimes,
936 memory_reuse_opportunities,
937 })
938 }
939
940 fn find_critical_path(&self, graph: &ComputationGraph) -> Result<Vec<String>> {
941 let mut path = Vec::new();
943 let mut current_depth = graph.metadata.max_depth;
944
945 while current_depth > 0 {
946 for node in graph.nodes.values() {
948 if node.depth == current_depth {
949 path.push(node.id.clone());
950 current_depth -= 1;
951 break;
952 }
953 }
954 current_depth = current_depth.saturating_sub(1);
955 }
956
957 path.reverse();
958 Ok(path)
959 }
960
961 fn calculate_statistics(&self, graph: &ComputationGraph) -> Result<GraphStatistics> {
962 let mut nodes_by_type: HashMap<OperationType, usize> = HashMap::new();
963 for node in graph.nodes.values() {
964 *nodes_by_type.entry(node.operation_type.clone()).or_insert(0) += 1;
965 }
966
967 let total_fan_in: usize = graph.edges.values().map(|deps| deps.len()).sum();
968 let total_fan_out = total_fan_in; let average_fan_in = total_fan_in as f64 / graph.nodes.len() as f64;
970 let average_fan_out = total_fan_out as f64 / graph.nodes.len() as f64;
971
972 Ok(GraphStatistics {
973 nodes_by_type,
974 average_fan_in,
975 average_fan_out,
976 diameter: graph.metadata.max_depth,
977 clustering_coefficient: 0.0, strongly_connected_components: graph.nodes.len(), })
980 }
981
982 fn generate_recommendations(&self, analysis: &GraphAnalysisResult) -> Result<Vec<String>> {
983 let mut recommendations = Vec::new();
984
985 if let Some(ref memory_analysis) = analysis.memory_analysis {
987 if memory_analysis.total_memory_usage > 1024 * 1024 * 1024 {
988 recommendations.push(
990 "Consider using gradient checkpointing to reduce memory usage".to_string(),
991 );
992 }
993 if memory_analysis.fragmentation_ratio > 0.2 {
994 recommendations
995 .push("Implement memory pooling to reduce fragmentation".to_string());
996 }
997 }
998
999 if let Some(ref flop_analysis) = analysis.flop_analysis {
1001 if flop_analysis.arithmetic_intensity < 1.0 {
1002 recommendations
1003 .push("Consider kernel fusion to improve arithmetic intensity".to_string());
1004 }
1005 if flop_analysis.complexity_analysis.parallelization_potential > 0.5 {
1006 recommendations.push(
1007 "Explore parallelization opportunities for compute-intensive operations"
1008 .to_string(),
1009 );
1010 }
1011 }
1012
1013 if analysis.optimization_opportunities.len() > 3 {
1015 recommendations.push(
1016 "Multiple optimization opportunities detected - prioritize by estimated impact"
1017 .to_string(),
1018 );
1019 }
1020
1021 if let Some(ref bottleneck_analysis) = analysis.bottleneck_analysis {
1023 if !bottleneck_analysis.bottleneck_nodes.is_empty() {
1024 recommendations.push(
1025 "Address bottleneck operations through optimization or parallelization"
1026 .to_string(),
1027 );
1028 }
1029 }
1030
1031 Ok(recommendations)
1032 }
1033
1034 fn get_node_color(&self, op_type: &OperationType) -> &'static str {
1035 match op_type {
1036 OperationType::MatMul | OperationType::Dot => "lightblue",
1037 OperationType::Add
1038 | OperationType::Subtract
1039 | OperationType::Multiply
1040 | OperationType::Divide => "lightgreen",
1041 OperationType::ReLU
1042 | OperationType::Sigmoid
1043 | OperationType::Tanh
1044 | OperationType::GELU => "orange",
1045 OperationType::LayerNorm | OperationType::BatchNorm | OperationType::RMSNorm => {
1046 "yellow"
1047 },
1048 OperationType::Conv1D | OperationType::Conv2D | OperationType::Conv3D => "lightcoral",
1049 OperationType::Attention | OperationType::MultiHeadAttention => "purple",
1050 OperationType::Embedding | OperationType::PositionalEmbedding => "pink",
1051 _ => "lightgray",
1052 }
1053 }
1054}
1055
1056impl fmt::Display for OperationType {
1057 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1058 match self {
1059 OperationType::Custom(name) => write!(f, "Custom({})", name),
1060 _ => write!(f, "{:?}", self),
1061 }
1062 }
1063}
1064
1065impl Default for ComputationGraphAnalyzer {
1066 fn default() -> Self {
1067 Self::new(GraphAnalysisConfig::default())
1068 }
1069}
1070
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build one `(id, operation, dependencies)` triple.
    fn op(
        id: &str,
        op_type: OperationType,
        deps: &[&str],
    ) -> (String, OperationType, Vec<String>) {
        (
            id.to_string(),
            op_type,
            deps.iter().map(|d| d.to_string()).collect(),
        )
    }

    #[test]
    fn test_computation_graph_creation() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        // Simple two-layer MLP: input -> linear -> relu -> linear -> output.
        let operations = vec![
            op("input", OperationType::Custom("Input".to_string()), &[]),
            op("linear1", OperationType::MatMul, &["input"]),
            op("relu1", OperationType::ReLU, &["linear1"]),
            op("linear2", OperationType::MatMul, &["relu1"]),
            op("output", OperationType::Custom("Output".to_string()), &["linear2"]),
        ];

        let graph_id = analyzer
            .create_graph("test_model".to_string(), operations)
            .expect("operation failed in test");
        let analysis = analyzer.analyze_graph(graph_id).expect("operation failed in test");

        // Four distinct types: Custom("Input"), MatMul, ReLU, Custom("Output").
        assert_eq!(analysis.statistics.nodes_by_type.len(), 4);
        assert!(!analysis.critical_path.is_empty());
    }

    #[test]
    fn test_optimization_detection() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        // MatMul feeding an Add should be flagged as a GEMM fusion candidate.
        let operations = vec![
            op("input", OperationType::Custom("Input".to_string()), &[]),
            op("matmul", OperationType::MatMul, &["input"]),
            op("add", OperationType::Add, &["matmul"]),
        ];

        let graph_id = analyzer
            .create_graph("fusion_test".to_string(), operations)
            .expect("operation failed in test");
        let analysis = analyzer.analyze_graph(graph_id).expect("operation failed in test");

        assert!(analysis
            .optimization_opportunities
            .iter()
            .any(|o| o.optimization_type == OptimizationType::OperationFusion));
    }

    #[test]
    fn test_dot_export() {
        let mut analyzer = ComputationGraphAnalyzer::default();

        let operations = vec![
            op("a", OperationType::MatMul, &[]),
            op("b", OperationType::ReLU, &["a"]),
        ];

        let graph_id = analyzer
            .create_graph("simple".to_string(), operations)
            .expect("operation failed in test");
        let dot = analyzer.export_to_dot(graph_id).expect("operation failed in test");

        // Spot-check the DOT output: header plus both operation labels.
        assert!(dot.contains("digraph"));
        assert!(dot.contains("MatMul"));
        assert!(dot.contains("ReLU"));
    }
}