1#![allow(unused_variables)] use crate::compiler::{ComputationGraph, DeviceType, GraphNode, HardwareTarget};
17use crate::errors::invalid_input;
18use crate::errors::TrustformersError;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet, VecDeque};
21
/// Summary of the predicted runtime behavior of a computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceAnalysis {
    /// Sum of the per-node estimated execution times, in milliseconds.
    pub total_execution_time_ms: f64,
    /// Node ids along the longest (critical) dependency path.
    pub critical_path: Vec<usize>,
    /// Estimated duration of the critical path, in milliseconds.
    pub critical_path_length_ms: f64,
    /// Groups of node ids with no dependency path between them.
    pub parallelizable_operations: Vec<Vec<usize>>,
    /// Nodes that dominate execution time, most critical first.
    pub bottlenecks: Vec<BottleneckInfo>,
    /// Evenness of work distribution; 1.0 is perfectly balanced.
    pub load_balance_score: f64,
    /// Predicted utilization of the target hardware.
    pub hardware_utilization: HardwareUtilization,
}

/// Details about a single operation identified as a performance bottleneck.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckInfo {
    /// Id of the offending graph node.
    pub node_id: usize,
    /// Operation type name (e.g. "MatMul").
    pub operation_type: String,
    /// Estimated execution time, in milliseconds.
    pub execution_time_ms: f64,
    /// Estimated memory footprint, in megabytes.
    pub memory_usage_mb: f64,
    /// Percentage of total graph time attributed to this node.
    pub criticality_score: f64,
    /// Human-readable optimization hints for this operation type.
    pub optimization_suggestions: Vec<String>,
}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct HardwareUtilization {
55 pub compute_utilization: f64, pub memory_utilization: f64, pub memory_bandwidth_utilization: f64,
58 pub cache_hit_rate_prediction: f64,
59 pub parallel_efficiency: f64,
60}
61
/// Result of simulating memory behavior over the graph's schedule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAnalysis {
    /// Largest allocated amount observed across all snapshots, in bytes.
    pub peak_memory_usage: u64,
    /// One snapshot per executed operation, in schedule order.
    pub memory_timeline: Vec<MemorySnapshot>,
    /// Recurring allocation behaviors detected in the graph.
    pub allocation_patterns: Vec<AllocationPattern>,
    /// Tensors whose buffers could be shared to save memory.
    pub reuse_opportunities: Vec<ReuseOpportunity>,
    /// Heap-fragmentation estimates.
    pub fragmentation_analysis: FragmentationAnalysis,
}

/// Memory state captured immediately after one operation executes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySnapshot {
    /// Id of the operation that just ran.
    pub operation_id: usize,
    /// Total live allocation at this point, in bytes.
    pub allocated_memory: u64,
    /// Tensors still alive at this point in the schedule.
    pub active_tensors: Vec<TensorInfo>,
    /// Allocated memory relative to an assumed device budget.
    pub memory_pressure: f64,
}

/// Metadata for a single simulated tensor allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Synthetic tensor id (derived from the producing node).
    pub id: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Element type name (e.g. "f32").
    pub dtype: String,
    /// Total size of the tensor, in bytes.
    pub size_bytes: u64,
    /// Schedule position where the tensor is allocated.
    pub lifetime_start: usize,
    /// Schedule position after which the tensor may be freed.
    pub lifetime_end: usize,
}

/// A recurring allocation behavior and its optimization potential.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocationPattern {
    /// Category of the pattern.
    pub pattern_type: AllocationType,
    /// How often the pattern occurs.
    pub frequency: usize,
    /// Total bytes involved across occurrences.
    pub total_size: u64,
    /// Estimated benefit of optimizing this pattern (0.0..=1.0 heuristic).
    pub optimization_potential: f64,
}

/// Categories of allocation behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationType {
    Sequential,
    Scattered,
    Temporary,
    LongLived,
    Reusable,
}

/// A tensor whose buffer could be shared with other tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReuseOpportunity {
    /// Tensor that could reuse an existing buffer.
    pub tensor_id: usize,
    /// Tensors whose buffers it could share.
    pub reusable_with: Vec<usize>,
    /// Bytes saved if the reuse is implemented.
    pub memory_savings: u64,
    /// How hard the reuse would be to implement.
    pub implementation_complexity: ComplexityLevel,
}

/// Coarse implementation-effort rating.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComplexityLevel {
    Low,
    Medium,
    High,
}

/// Heap-fragmentation estimates for the simulated allocator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FragmentationAnalysis {
    /// Fraction of the heap lost to fragmentation.
    pub fragmentation_ratio: f64,
    /// Size of the largest contiguous free block, in bytes.
    pub largest_free_block: u64,
    /// How efficiently allocations pack into the heap.
    pub allocation_efficiency: f64,
    /// Estimated benefit of defragmenting.
    pub defragmentation_potential: f64,
}
139
/// Structural analysis of the graph's dependencies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyAnalysis {
    /// A valid execution order of node ids (Kahn's algorithm).
    pub topological_order: Vec<usize>,
    /// Weakly connected components of the graph.
    pub connected_components: Vec<Vec<usize>>,
    /// Individual dependencies between operations.
    pub data_dependencies: Vec<Dependency>,
    /// Detected loops and related opportunities.
    pub loop_analysis: LoopAnalysis,
    /// Parallel regions, load balance and communication estimates.
    pub parallelization: ParallelizationAnalysis,
}

/// A single directed dependency between two operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dependency {
    /// Producing node id.
    pub from: usize,
    /// Consuming node id.
    pub to: usize,
    /// Nature of the dependency.
    pub dependency_type: DependencyType,
    /// Bytes flowing along this dependency.
    pub data_size: u64,
    /// Estimated latency contribution of this dependency.
    pub latency_impact: f64,
}

/// Kinds of inter-operation dependencies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DependencyType {
    DataFlow,
    Control,
    Memory,
    Synchronization,
}

/// Loop structures detected in the graph and what can be done with them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopAnalysis {
    /// Loops found in the graph.
    pub detected_loops: Vec<LoopInfo>,
    /// Dependencies carried across loop iterations.
    pub loop_carried_dependencies: Vec<Dependency>,
    /// Operation groups that could be vectorized.
    pub vectorization_opportunities: Vec<VectorizationOpportunity>,
}

/// A single detected loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopInfo {
    /// Identifier of the loop.
    pub loop_id: usize,
    /// Node ids that form the loop body.
    pub operations: Vec<usize>,
    /// Trip count when statically known.
    pub iteration_count: Option<usize>,
    /// How the loop terminates.
    pub loop_type: LoopType,
}

/// Loop-termination categories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoopType {
    CountBased,
    DataDependent,
    Infinite,
    Unknown,
}

/// A group of operations that could be executed with SIMD instructions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorizationOpportunity {
    /// Node ids that would be vectorized together.
    pub operations: Vec<usize>,
    /// Number of lanes of the proposed vectorization.
    pub vector_width: usize,
    /// Estimated speedup factor.
    pub performance_gain: f64,
    /// Target instruction set (e.g. AVX2, NEON).
    pub instruction_set: String,
}
206
/// Parallel execution opportunities and their expected costs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
    /// Regions of the graph that can run in parallel.
    pub parallel_regions: Vec<ParallelRegion>,
    /// Node ids where threads must synchronize.
    pub synchronization_points: Vec<usize>,
    /// How evenly work divides across workers.
    pub load_balance_analysis: LoadBalanceAnalysis,
    /// Data-movement costs of parallel execution.
    pub communication_analysis: CommunicationAnalysis,
}

/// A set of operations that can execute concurrently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelRegion {
    /// Node ids inside the region.
    pub operations: Vec<usize>,
    /// Kind of parallelism exploited.
    pub parallelism_type: ParallelismType,
    /// Expected speedup over sequential execution.
    pub estimated_speedup: f64,
    /// Resources needed to realize the speedup.
    pub resource_requirements: ResourceRequirements,
}

/// Kinds of parallelism a region can exploit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParallelismType {
    DataParallel,
    TaskParallel,
    Pipeline,
    Mixed,
}

/// Resources a parallel region needs to achieve its estimated speedup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequirements {
    /// Minimum thread count for any benefit.
    pub min_threads: usize,
    /// Thread count at which speedup is maximized.
    pub optimal_threads: usize,
    /// Memory needed per thread, in bytes.
    pub memory_per_thread: u64,
    /// Bandwidth needed for inter-thread communication.
    pub communication_bandwidth: f64,
}

/// How evenly work is distributed across parallel workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalanceAnalysis {
    /// 1.0 means perfectly even distribution.
    pub balance_score: f64,
    /// Per-worker share of the total work.
    pub work_distribution: Vec<f64>,
    /// Fraction of time lost to synchronization.
    pub synchronization_overhead: f64,
    /// Suggested remedies for imbalance.
    pub recommendations: Vec<String>,
}

/// Data-movement costs incurred by parallel execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationAnalysis {
    /// Total bytes exchanged between workers.
    pub communication_volume: u64,
    /// Recurring communication shapes observed.
    pub communication_patterns: Vec<CommunicationPattern>,
    /// Fraction of interconnect bandwidth consumed.
    pub network_utilization: f64,
    /// How sensitive overall time is to communication latency.
    pub latency_sensitivity: f64,
}

/// One recurring communication shape and its cost profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationPattern {
    /// Collective/point-to-point category.
    pub pattern_type: CommunicationType,
    /// Bytes moved per occurrence.
    pub data_size: u64,
    /// How often the pattern occurs.
    pub frequency: usize,
    /// Estimated benefit of optimizing this pattern.
    pub optimization_potential: f64,
}

/// Standard communication primitives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationType {
    AllToAll,
    AllReduce,
    PointToPoint,
    Broadcast,
    Gather,
    Scatter,
}
277
/// Analyzes computation graphs for performance, memory and dependency
/// characteristics against a specific hardware target.
pub struct GraphAnalyzer {
    // Hardware description used for all time/utilization estimates.
    hardware_target: HardwareTarget,
    // Cache keyed by analysis name; currently populated by no code path
    // in this file (hence the dead_code allowance).
    #[allow(dead_code)]
    analysis_cache: HashMap<String, AnalysisResult>,
}

/// A cached result from any of the three analysis pipelines.
#[derive(Debug, Clone)]
pub enum AnalysisResult {
    Performance(PerformanceAnalysis),
    Memory(MemoryAnalysis),
    Dependency(DependencyAnalysis),
}
292
293impl GraphAnalyzer {
294 pub fn new(hardware_target: HardwareTarget) -> Self {
296 Self {
297 hardware_target,
298 analysis_cache: HashMap::new(),
299 }
300 }
301
302 pub fn analyze_performance(
304 &mut self,
305 graph: &ComputationGraph,
306 ) -> Result<PerformanceAnalysis, TrustformersError> {
307 let critical_path = self.find_critical_path(graph)?;
309 let critical_path_length = self.calculate_path_length(&critical_path, graph)?;
310
311 let bottlenecks = self.detect_bottlenecks(graph)?;
313
314 let parallelizable_ops = self.find_parallelizable_operations(graph)?;
316
317 let load_balance_score = self.calculate_load_balance_score(graph)?;
319
320 let hardware_utilization = self.predict_hardware_utilization(graph)?;
322
323 let total_execution_time =
324 graph.nodes.iter().map(|node| self.estimate_execution_time(node)).sum();
325
326 Ok(PerformanceAnalysis {
327 total_execution_time_ms: total_execution_time,
328 critical_path,
329 critical_path_length_ms: critical_path_length,
330 parallelizable_operations: parallelizable_ops,
331 bottlenecks,
332 load_balance_score,
333 hardware_utilization,
334 })
335 }
336
337 pub fn analyze_memory(
339 &mut self,
340 graph: &ComputationGraph,
341 ) -> Result<MemoryAnalysis, TrustformersError> {
342 let memory_timeline = self.simulate_memory_usage(graph)?;
343 let peak_memory = memory_timeline
344 .iter()
345 .map(|snapshot| snapshot.allocated_memory)
346 .max()
347 .unwrap_or(0);
348
349 let allocation_patterns = self.analyze_allocation_patterns(graph)?;
350 let reuse_opportunities = self.find_reuse_opportunities(graph)?;
351 let fragmentation_analysis = self.analyze_fragmentation(graph)?;
352
353 Ok(MemoryAnalysis {
354 peak_memory_usage: peak_memory,
355 memory_timeline,
356 allocation_patterns,
357 reuse_opportunities,
358 fragmentation_analysis,
359 })
360 }
361
362 pub fn analyze_dependencies(
364 &mut self,
365 graph: &ComputationGraph,
366 ) -> Result<DependencyAnalysis, TrustformersError> {
367 let topological_order = self.topological_sort(graph)?;
368 let connected_components = self.find_connected_components(graph)?;
369 let data_dependencies = self.analyze_data_dependencies(graph)?;
370 let loop_analysis = self.analyze_loops(graph)?;
371 let parallelization = self.analyze_parallelization(graph)?;
372
373 Ok(DependencyAnalysis {
374 topological_order,
375 connected_components,
376 data_dependencies,
377 loop_analysis,
378 parallelization,
379 })
380 }
381
382 fn find_critical_path(
384 &self,
385 graph: &ComputationGraph,
386 ) -> Result<Vec<usize>, TrustformersError> {
387 let mut longest_path = HashMap::new();
388 let mut predecessors = HashMap::new();
389
390 for node in &graph.nodes {
392 longest_path.insert(node.id, 0.0);
393 }
394
395 let topo_order = self.topological_sort(graph)?;
397
398 for &node_id in &topo_order {
399 let node_time = self.estimate_execution_time(&graph.nodes[node_id]);
400
401 for edge in &graph.edges {
402 if edge.from != node_id {
403 continue;
404 }
405 let new_distance = longest_path[&node_id] + node_time;
406 if new_distance > longest_path[&edge.to] {
407 longest_path.insert(edge.to, new_distance);
408 predecessors.insert(edge.to, node_id);
409 }
410 }
411 }
412
413 let end_node = longest_path
415 .iter()
416 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
417 .map(|(node_id, _)| *node_id)
418 .unwrap_or(0);
419
420 let mut path = Vec::new();
422 let mut current = end_node;
423
424 while let Some(&predecessor) = predecessors.get(¤t) {
425 path.push(current);
426 current = predecessor;
427 }
428 path.push(current);
429 path.reverse();
430
431 Ok(path)
432 }
433
434 fn calculate_path_length(
436 &self,
437 path: &[usize],
438 graph: &ComputationGraph,
439 ) -> Result<f64, TrustformersError> {
440 let total_time = path
441 .iter()
442 .map(|&node_id| {
443 if let Some(node) = graph.get_node(node_id) {
444 self.estimate_execution_time(node)
445 } else {
446 0.0
447 }
448 })
449 .sum();
450
451 Ok(total_time)
452 }
453
454 fn estimate_execution_time(&self, node: &GraphNode) -> f64 {
456 let base_time = match node.op_type.as_str() {
458 "MatMul" => {
459 let flops = node.compute_cost;
461 match self.hardware_target.device_type {
462 DeviceType::GPU => flops / 10e12, DeviceType::CPU => flops / 1e12, _ => flops / 1e9, }
466 },
467 "Conv2D" => node.compute_cost / 5e12, "Add" | "Mul" | "Sub" | "Div" => node.compute_cost / 1e13, "ReLU" | "Sigmoid" | "Tanh" => node.compute_cost / 1e12,
470 _ => node.compute_cost / 1e9, };
472
473 let memory_time = node.memory_cost / self.hardware_target.memory_bandwidth;
475
476 (base_time + memory_time) * 1000.0 }
478
479 fn detect_bottlenecks(
481 &self,
482 graph: &ComputationGraph,
483 ) -> Result<Vec<BottleneckInfo>, TrustformersError> {
484 let mut bottlenecks = Vec::new();
485
486 let total_time: f64 =
487 graph.nodes.iter().map(|node| self.estimate_execution_time(node)).sum();
488
489 for node in &graph.nodes {
490 let execution_time = self.estimate_execution_time(node);
491 let time_percentage = execution_time / total_time;
492
493 if time_percentage > 0.1 {
495 let memory_usage = node.memory_cost / (1024.0 * 1024.0); let criticality_score = time_percentage * 100.0;
497
498 let suggestions = self.generate_optimization_suggestions(node);
499
500 bottlenecks.push(BottleneckInfo {
501 node_id: node.id,
502 operation_type: node.op_type.clone(),
503 execution_time_ms: execution_time,
504 memory_usage_mb: memory_usage,
505 criticality_score,
506 optimization_suggestions: suggestions,
507 });
508 }
509 }
510
511 bottlenecks.sort_by(|a, b| {
513 b.criticality_score
514 .partial_cmp(&a.criticality_score)
515 .expect("Partial comparison failed")
516 });
517
518 Ok(bottlenecks)
519 }
520
521 fn generate_optimization_suggestions(&self, node: &GraphNode) -> Vec<String> {
523 let mut suggestions = Vec::new();
524
525 match node.op_type.as_str() {
526 "MatMul" => {
527 suggestions.push("Consider using optimized BLAS libraries".to_string());
528 suggestions.push("Try different matrix multiplication algorithms".to_string());
529 suggestions
530 .push("Consider batch processing for multiple small matrices".to_string());
531 },
532 "Conv2D" => {
533 suggestions.push("Use optimized convolution libraries (cuDNN, oneDNN)".to_string());
534 suggestions
535 .push("Consider different convolution algorithms (Winograd, FFT)".to_string());
536 suggestions.push("Try different data layouts (NCHW vs NHWC)".to_string());
537 },
538 "Attention" => {
539 suggestions.push(
540 "Use FlashAttention or similar memory-efficient implementations".to_string(),
541 );
542 suggestions.push("Consider attention sparsity patterns".to_string());
543 suggestions.push("Try different attention approximations".to_string());
544 },
545 _ => {
546 suggestions.push("Profile the operation to understand bottlenecks".to_string());
547 suggestions
548 .push("Consider operation fusion with neighboring operations".to_string());
549 },
550 }
551
552 suggestions
553 }
554
555 fn find_parallelizable_operations(
557 &self,
558 graph: &ComputationGraph,
559 ) -> Result<Vec<Vec<usize>>, TrustformersError> {
560 let mut parallel_groups = Vec::new();
561 let mut visited = HashSet::new();
562
563 for (i, node1) in graph.nodes.iter().enumerate() {
565 if visited.contains(&i) {
566 continue;
567 }
568
569 let mut group = vec![i];
570 visited.insert(i);
571
572 for (j, node2) in graph.nodes.iter().enumerate() {
573 if i == j || visited.contains(&j) {
574 continue;
575 }
576 if self.has_dependency_path(i, j, graph) || self.has_dependency_path(j, i, graph) {
578 continue;
579 }
580 group.push(j);
581 visited.insert(j);
582 }
583
584 if group.len() > 1 {
585 parallel_groups.push(group);
586 }
587 }
588
589 Ok(parallel_groups)
590 }
591
592 fn has_dependency_path(&self, from: usize, to: usize, graph: &ComputationGraph) -> bool {
594 let mut visited = HashSet::new();
595 let mut queue = VecDeque::new();
596
597 queue.push_back(from);
598 visited.insert(from);
599
600 while let Some(current) = queue.pop_front() {
601 if current == to {
602 return true;
603 }
604
605 for edge in &graph.edges {
606 if edge.from == current && !visited.contains(&edge.to) {
607 visited.insert(edge.to);
608 queue.push_back(edge.to);
609 }
610 }
611 }
612
613 false
614 }
615
616 fn calculate_load_balance_score(
618 &self,
619 graph: &ComputationGraph,
620 ) -> Result<f64, TrustformersError> {
621 let execution_times: Vec<f64> =
622 graph.nodes.iter().map(|node| self.estimate_execution_time(node)).collect();
623
624 if execution_times.is_empty() {
625 return Ok(1.0);
626 }
627
628 let mean_time: f64 = execution_times.iter().sum::<f64>() / execution_times.len() as f64;
629 let variance: f64 =
630 execution_times.iter().map(|&time| (time - mean_time).powi(2)).sum::<f64>()
631 / execution_times.len() as f64;
632
633 let coefficient_of_variation = variance.sqrt() / mean_time.max(1e-10);
634
635 Ok((1.0 / (1.0 + coefficient_of_variation)).min(1.0))
637 }
638
639 fn predict_hardware_utilization(
641 &self,
642 graph: &ComputationGraph,
643 ) -> Result<HardwareUtilization, TrustformersError> {
644 let total_compute = graph.total_compute_cost();
645 let total_memory = graph.total_memory_cost();
646
647 let compute_intensive_ops = graph
649 .nodes
650 .iter()
651 .filter(|node| matches!(node.op_type.as_str(), "MatMul" | "Conv2D" | "Attention"))
652 .count();
653
654 let compute_utilization =
655 (compute_intensive_ops as f64 / graph.nodes.len().max(1) as f64) * 0.8;
656
657 let estimated_memory = total_memory;
659 let available_memory = match self.hardware_target.device_type {
660 DeviceType::GPU => 16e9, DeviceType::CPU => 64e9, _ => 8e9, };
664
665 let memory_utilization = (estimated_memory / available_memory).min(1.0);
666
667 let memory_bandwidth_utilization =
669 (total_memory / 1e9) / self.hardware_target.memory_bandwidth;
670
671 let cache_hit_rate_prediction = 0.8; let parallelizable_ops = self.find_parallelizable_operations(graph)?.len();
676 let parallel_efficiency =
677 (parallelizable_ops as f64 / graph.nodes.len().max(1) as f64) * 0.9;
678
679 Ok(HardwareUtilization {
680 compute_utilization,
681 memory_utilization,
682 memory_bandwidth_utilization,
683 cache_hit_rate_prediction,
684 parallel_efficiency,
685 })
686 }
687
688 fn simulate_memory_usage(
690 &self,
691 graph: &ComputationGraph,
692 ) -> Result<Vec<MemorySnapshot>, TrustformersError> {
693 let mut snapshots = Vec::new();
694 let mut active_tensors = HashMap::new();
695 let mut total_memory = 0u64;
696
697 let topo_order = self.topological_sort(graph)?;
698
699 for &node_id in &topo_order {
700 if let Some(node) = graph.get_node(node_id) {
701 for (i, shape) in node.output_shapes.iter().enumerate() {
703 let tensor_size = self.calculate_tensor_size(shape, "f32");
704 let tensor_info = TensorInfo {
705 id: node_id * 100 + i, shape: shape.clone(),
707 dtype: "f32".to_string(),
708 size_bytes: tensor_size,
709 lifetime_start: node_id,
710 lifetime_end: node_id + 10, };
712
713 active_tensors.insert(tensor_info.id, tensor_info);
714 total_memory += tensor_size;
715 }
716
717 let memory_pressure = total_memory as f64 / 16e9; let snapshot = MemorySnapshot {
721 operation_id: node_id,
722 allocated_memory: total_memory,
723 active_tensors: active_tensors.values().cloned().collect(),
724 memory_pressure,
725 };
726
727 snapshots.push(snapshot);
728
729 active_tensors.retain(|_, tensor| tensor.lifetime_end > node_id);
731 total_memory = active_tensors.values().map(|t| t.size_bytes).sum();
732 }
733 }
734
735 Ok(snapshots)
736 }
737
738 fn calculate_tensor_size(&self, shape: &[usize], dtype: &str) -> u64 {
740 let element_size = match dtype {
741 "f32" | "i32" => 4,
742 "f16" | "i16" => 2,
743 "f64" | "i64" => 8,
744 "i8" | "u8" => 1,
745 _ => 4, };
747
748 let elements: usize = shape.iter().product();
749 (elements * element_size) as u64
750 }
751
752 fn topological_sort(&self, graph: &ComputationGraph) -> Result<Vec<usize>, TrustformersError> {
754 let mut in_degree = vec![0; graph.nodes.len()];
755 let mut adj_list = vec![Vec::new(); graph.nodes.len()];
756
757 for edge in &graph.edges {
759 if edge.from < graph.nodes.len() && edge.to < graph.nodes.len() {
760 adj_list[edge.from].push(edge.to);
761 in_degree[edge.to] += 1;
762 }
763 }
764
765 let mut queue = VecDeque::new();
767 let mut result = Vec::new();
768
769 for (i, °ree) in in_degree.iter().enumerate() {
771 if degree == 0 {
772 queue.push_back(i);
773 }
774 }
775
776 while let Some(node) = queue.pop_front() {
777 result.push(node);
778
779 for &neighbor in &adj_list[node] {
780 in_degree[neighbor] -= 1;
781 if in_degree[neighbor] == 0 {
782 queue.push_back(neighbor);
783 }
784 }
785 }
786
787 if result.len() != graph.nodes.len() {
788 return Err(invalid_input("Graph contains cycles"));
789 }
790
791 Ok(result)
792 }
793
794 fn find_connected_components(
796 &self,
797 _graph: &ComputationGraph,
798 ) -> Result<Vec<Vec<usize>>, TrustformersError> {
799 Ok(Vec::new()) }
801
802 fn analyze_data_dependencies(
803 &self,
804 _graph: &ComputationGraph,
805 ) -> Result<Vec<Dependency>, TrustformersError> {
806 Ok(Vec::new()) }
808
809 fn analyze_loops(&self, _graph: &ComputationGraph) -> Result<LoopAnalysis, TrustformersError> {
810 Ok(LoopAnalysis {
811 detected_loops: Vec::new(),
812 loop_carried_dependencies: Vec::new(),
813 vectorization_opportunities: Vec::new(),
814 })
815 }
816
817 fn analyze_parallelization(
818 &self,
819 _graph: &ComputationGraph,
820 ) -> Result<ParallelizationAnalysis, TrustformersError> {
821 Ok(ParallelizationAnalysis {
822 parallel_regions: Vec::new(),
823 synchronization_points: Vec::new(),
824 load_balance_analysis: LoadBalanceAnalysis {
825 balance_score: 0.8,
826 work_distribution: Vec::new(),
827 synchronization_overhead: 0.1,
828 recommendations: Vec::new(),
829 },
830 communication_analysis: CommunicationAnalysis {
831 communication_volume: 0,
832 communication_patterns: Vec::new(),
833 network_utilization: 0.5,
834 latency_sensitivity: 0.3,
835 },
836 })
837 }
838
839 fn analyze_allocation_patterns(
840 &self,
841 _graph: &ComputationGraph,
842 ) -> Result<Vec<AllocationPattern>, TrustformersError> {
843 Ok(Vec::new()) }
845
846 fn find_reuse_opportunities(
847 &self,
848 _graph: &ComputationGraph,
849 ) -> Result<Vec<ReuseOpportunity>, TrustformersError> {
850 Ok(Vec::new()) }
852
853 fn analyze_fragmentation(
854 &self,
855 _graph: &ComputationGraph,
856 ) -> Result<FragmentationAnalysis, TrustformersError> {
857 Ok(FragmentationAnalysis {
858 fragmentation_ratio: 0.1,
859 largest_free_block: 1024 * 1024 * 1024, allocation_efficiency: 0.9,
861 defragmentation_potential: 0.05,
862 })
863 }
864}
865
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::{ComputationGraph, GraphNode, HardwareTarget};

    // Builds a one-node graph (a single 128x256 by 256x512 MatMul) shared
    // by all analyzer tests.
    fn create_test_graph() -> ComputationGraph {
        let mut graph = ComputationGraph::new();

        let node1 = GraphNode {
            id: 0,
            op_type: "MatMul".to_string(),
            attributes: HashMap::new(),
            input_shapes: vec![vec![128, 256], vec![256, 512]],
            output_shapes: vec![vec![128, 512]],
            compute_cost: 100.0,
            memory_cost: 50.0,
        };

        graph.add_node(node1);
        graph
    }

    #[test]
    fn test_graph_analyzer_creation() {
        let hardware = HardwareTarget::default();
        let analyzer = GraphAnalyzer::new(hardware);
        // A fresh analyzer starts with an empty analysis cache.
        assert_eq!(analyzer.analysis_cache.len(), 0);
    }

    #[test]
    fn test_performance_analysis() {
        let hardware = HardwareTarget::default();
        let mut analyzer = GraphAnalyzer::new(hardware);
        let graph = create_test_graph();

        let result = analyzer.analyze_performance(&graph);
        assert!(result.is_ok());

        // Execution-time estimates must never be negative.
        let analysis = result.expect("operation failed in test");
        assert!(analysis.total_execution_time_ms >= 0.0);
    }

    #[test]
    fn test_memory_analysis() {
        let hardware = HardwareTarget::default();
        let mut analyzer = GraphAnalyzer::new(hardware);
        let graph = create_test_graph();

        let result = analyzer.analyze_memory(&graph);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dependency_analysis() {
        let hardware = HardwareTarget::default();
        let mut analyzer = GraphAnalyzer::new(hardware);
        let graph = create_test_graph();

        let result = analyzer.analyze_dependencies(&graph);
        assert!(result.is_ok());
    }

    #[test]
    fn test_critical_path_analysis() {
        let hardware = HardwareTarget::default();
        let analyzer = GraphAnalyzer::new(hardware);
        let graph = create_test_graph();

        // Even a single-node graph has a non-empty critical path.
        let result = analyzer.find_critical_path(&graph);
        assert!(result.is_ok());
        assert!(!result.expect("operation failed in test").is_empty());
    }

    #[test]
    fn test_topological_sort() {
        let hardware = HardwareTarget::default();
        let analyzer = GraphAnalyzer::new(hardware);
        let graph = create_test_graph();

        // An acyclic graph yields an order covering every node exactly once.
        let result = analyzer.topological_sort(&graph);
        assert!(result.is_ok());
        assert_eq!(
            result.expect("operation failed in test").len(),
            graph.nodes.len()
        );
    }
}