1use crate::{FxGraph, Node, TorshResult};
10use petgraph::graph::NodeIndex;
11use petgraph::visit::EdgeRef;
12use scirs2_core::parallel_ops::*;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex, RwLock};
17use std::time::{Duration, Instant};
18
/// Parallel traversal utilities over a shared [`FxGraph`].
///
/// Wraps the graph in an `Arc` so worker closures spawned by the parallel
/// iterators can borrow it concurrently.
pub struct ParallelTraversal {
    // Graph shared (read-only) with parallel workers.
    graph: Arc<FxGraph>,
    // Requested worker count. NOTE(review): stored by the builder but never
    // applied to the parallel iterator pool — confirm intended behavior.
    thread_pool_size: Option<usize>,
}
24
25impl ParallelTraversal {
26 pub fn new(graph: Arc<FxGraph>) -> Self {
28 Self {
29 graph,
30 thread_pool_size: None,
31 }
32 }
33
34 pub fn with_thread_pool_size(mut self, size: usize) -> Self {
36 self.thread_pool_size = Some(size);
37 self
38 }
39
40 pub fn parallel_topological_traverse<F>(&self, visitor: F) -> TorshResult<()>
42 where
43 F: Fn(NodeIndex, &Node) + Send + Sync,
44 {
45 let dependencies = self.build_dependency_map();
47 let visited = HashSet::new();
48 let mut ready_nodes = Vec::new();
49
50 for (idx, _) in self.graph.nodes() {
52 if dependencies.get(&idx).map_or(true, |deps| deps.is_empty()) {
53 ready_nodes.push(idx);
54 }
55 }
56
57 let visited = Arc::new(Mutex::new(visited));
58 let dependencies = Arc::new(dependencies);
59
60 while !ready_nodes.is_empty() {
61 ready_nodes.par_iter().for_each(|&idx| {
63 if let Some(node) = self.graph.get_node(idx) {
64 visitor(idx, node);
65 visited
66 .lock()
67 .expect("lock should not be poisoned")
68 .insert(idx);
69 }
70 });
71
72 ready_nodes.clear();
74
75 for (idx, deps) in dependencies.iter() {
76 let visited_guard = visited.lock().expect("lock should not be poisoned");
77 if !visited_guard.contains(idx) {
78 if deps.iter().all(|dep| visited_guard.contains(dep)) {
79 ready_nodes.push(*idx);
80 }
81 }
82 drop(visited_guard);
84 }
85 }
86
87 Ok(())
88 }
89
90 pub fn parallel_dfs<F>(&self, start_nodes: Vec<NodeIndex>, visitor: F) -> TorshResult<()>
92 where
93 F: Fn(NodeIndex, &Node) + Send + Sync,
94 {
95 let visited = Arc::new(Mutex::new(HashSet::new()));
96 let visitor = Arc::new(visitor);
97
98 start_nodes.into_par_iter().for_each(|start| {
99 self.dfs_worker(start, visited.clone(), visitor.clone());
100 });
101
102 Ok(())
103 }
104
105 fn build_dependency_map(&self) -> HashMap<NodeIndex, Vec<NodeIndex>> {
107 let mut dependencies = HashMap::new();
108
109 for (idx, _) in self.graph.nodes() {
110 dependencies.insert(idx, Vec::new());
111 }
112
113 for edge_ref in self.graph.graph.edge_references() {
115 use petgraph::visit::EdgeRef;
116 let target = edge_ref.target();
117 let source = edge_ref.source();
118 dependencies
119 .get_mut(&target)
120 .expect("target node should exist in dependencies map")
121 .push(source);
122 }
123
124 dependencies
125 }
126
127 fn dfs_worker<F>(
129 &self,
130 node: NodeIndex,
131 visited: Arc<Mutex<HashSet<NodeIndex>>>,
132 visitor: Arc<F>,
133 ) where
134 F: Fn(NodeIndex, &Node) + Send + Sync,
135 {
136 let mut stack = vec![node];
137
138 while let Some(current) = stack.pop() {
139 let already_visited = {
140 let mut v = visited.lock().expect("lock should not be poisoned");
141 if v.contains(¤t) {
142 true
143 } else {
144 v.insert(current);
145 false
146 }
147 };
148
149 if already_visited {
150 continue;
151 }
152
153 if let Some(node_data) = self.graph.get_node(current) {
154 visitor(current, node_data);
155 }
156
157 for edge in self.graph.graph.edges(current) {
159 stack.push(edge.target());
160 }
161 }
162 }
163}
164
/// Thread-safe cache for operation results and extracted subgraphs,
/// with shared hit/miss/eviction statistics.
#[derive(Debug)]
pub struct GraphCache {
    // String results of previously computed operations, keyed by operation key.
    operation_cache: RwLock<HashMap<String, CachedResult>>,
    // Previously extracted subgraphs, keyed by name.
    subgraph_cache: RwLock<HashMap<String, Arc<FxGraph>>>,
    // Counters shared by both caches (hits, misses, evictions).
    cache_stats: Arc<Mutex<CacheStatistics>>,
    // Maximum entry count per cache before an eviction is triggered.
    max_cache_size: usize,
}
173
/// A single cached operation result plus the bookkeeping used for LRU eviction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// The cached result payload.
    pub result: String,
    /// How long the original computation took (time saved on a cache hit).
    pub computation_time: Duration,
    /// Number of times this entry has been written or read.
    pub access_count: u64,
    /// Timestamp of the most recent access; the LRU eviction key.
    pub last_accessed: std::time::SystemTime,
}
181
/// Aggregate cache counters, shared by the operation and subgraph caches.
#[derive(Debug, Default, Clone)]
pub struct CacheStatistics {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// NOTE(review): declared but not updated anywhere in this file — confirm.
    pub total_computation_time_saved: Duration,
}
189
190impl GraphCache {
191 pub fn new(max_cache_size: usize) -> Self {
193 Self {
194 operation_cache: RwLock::new(HashMap::new()),
195 subgraph_cache: RwLock::new(HashMap::new()),
196 cache_stats: Arc::new(Mutex::new(CacheStatistics::default())),
197 max_cache_size,
198 }
199 }
200
201 pub fn get_operation(&self, key: &str) -> Option<CachedResult> {
203 let mut cache = self
204 .operation_cache
205 .write()
206 .expect("lock should not be poisoned");
207 if let Some(result) = cache.get_mut(key) {
208 result.access_count += 1;
209 result.last_accessed = std::time::SystemTime::now();
210 self.cache_stats
211 .lock()
212 .expect("lock should not be poisoned")
213 .hits += 1;
214 Some(result.clone())
215 } else {
216 self.cache_stats
217 .lock()
218 .expect("lock should not be poisoned")
219 .misses += 1;
220 None
221 }
222 }
223
224 pub fn cache_operation(&self, key: String, result: String, computation_time: Duration) {
226 let mut cache = self
227 .operation_cache
228 .write()
229 .expect("lock should not be poisoned");
230
231 if cache.len() >= self.max_cache_size {
233 self.evict_lru_operation(&mut cache);
234 }
235
236 let cached_result = CachedResult {
237 result,
238 computation_time,
239 access_count: 1,
240 last_accessed: std::time::SystemTime::now(),
241 };
242
243 cache.insert(key, cached_result);
244 }
245
246 pub fn get_subgraph(&self, key: &str) -> Option<Arc<FxGraph>> {
248 let cache = self
249 .subgraph_cache
250 .read()
251 .expect("lock should not be poisoned");
252 if let Some(graph) = cache.get(key) {
253 self.cache_stats
254 .lock()
255 .expect("lock should not be poisoned")
256 .hits += 1;
257 Some(graph.clone())
258 } else {
259 self.cache_stats
260 .lock()
261 .expect("lock should not be poisoned")
262 .misses += 1;
263 None
264 }
265 }
266
267 pub fn cache_subgraph(&self, key: String, graph: Arc<FxGraph>) {
269 let mut cache = self
270 .subgraph_cache
271 .write()
272 .expect("lock should not be poisoned");
273
274 if cache.len() >= self.max_cache_size {
276 self.evict_lru_subgraph(&mut cache);
277 }
278
279 cache.insert(key, graph);
280 }
281
282 pub fn statistics(&self) -> CacheStatistics {
284 self.cache_stats
285 .lock()
286 .expect("lock should not be poisoned")
287 .clone()
288 }
289
290 pub fn clear(&self) {
292 self.operation_cache
293 .write()
294 .expect("lock should not be poisoned")
295 .clear();
296 self.subgraph_cache
297 .write()
298 .expect("lock should not be poisoned")
299 .clear();
300 *self
301 .cache_stats
302 .lock()
303 .expect("lock should not be poisoned") = CacheStatistics::default();
304 }
305
306 fn evict_lru_operation(&self, cache: &mut HashMap<String, CachedResult>) {
308 if let Some(lru_key) = cache
309 .iter()
310 .min_by_key(|(_, result)| result.last_accessed)
311 .map(|(key, _)| key.clone())
312 {
313 cache.remove(&lru_key);
314 self.cache_stats
315 .lock()
316 .expect("lock should not be poisoned")
317 .evictions += 1;
318 }
319 }
320
321 fn evict_lru_subgraph(&self, cache: &mut HashMap<String, Arc<FxGraph>>) {
323 if let Some(key) = cache.keys().next().cloned() {
324 cache.remove(&key);
325 self.cache_stats
326 .lock()
327 .expect("lock should not be poisoned")
328 .evictions += 1;
329 }
330 }
331}
332
333pub struct GraphCompression;
335
336impl GraphCompression {
337 pub fn deduplicate_operations(graph: &FxGraph) -> TorshResult<FxGraph> {
339 let mut compressed_graph = FxGraph::new();
340 let mut operation_map: HashMap<String, NodeIndex> = HashMap::new();
341 let mut node_mapping: HashMap<NodeIndex, NodeIndex> = HashMap::new();
342
343 for (old_idx, node) in graph.nodes() {
345 let operation_key = Self::operation_key(node);
346
347 if let Some(&existing_idx) = operation_map.get(&operation_key) {
348 node_mapping.insert(old_idx, existing_idx);
350 } else {
351 let new_idx = compressed_graph.graph.add_node(node.clone());
353 operation_map.insert(operation_key, new_idx);
354 node_mapping.insert(old_idx, new_idx);
355 }
356 }
357
358 for edge_ref in graph.graph.edge_references() {
360 use petgraph::visit::EdgeRef;
361 let old_source = edge_ref.source();
362 let old_target = edge_ref.target();
363
364 if let (Some(&new_source), Some(&new_target)) =
365 (node_mapping.get(&old_source), node_mapping.get(&old_target))
366 {
367 if !compressed_graph
369 .graph
370 .edges_connecting(new_source, new_target)
371 .next()
372 .is_some()
373 {
374 compressed_graph.graph.add_edge(
375 new_source,
376 new_target,
377 edge_ref.weight().clone(),
378 );
379 }
380 }
381 }
382
383 compressed_graph.inputs = graph
385 .inputs()
386 .iter()
387 .filter_map(|&idx| node_mapping.get(&idx).copied())
388 .collect();
389 compressed_graph.outputs = graph
390 .outputs()
391 .iter()
392 .filter_map(|&idx| node_mapping.get(&idx).copied())
393 .collect();
394
395 Ok(compressed_graph)
396 }
397
398 fn operation_key(node: &Node) -> String {
400 match node {
401 Node::Input(name) => format!("input:{name}"),
402 Node::Call(op, args) => {
403 let args_str = args.join(",");
404 format!("call:{op}:{args_str}")
405 }
406 Node::Output => "output".into(),
407 Node::Conditional {
408 condition,
409 then_branch,
410 else_branch,
411 } => {
412 format!(
413 "conditional:{}:{}:{}",
414 condition,
415 then_branch.join(","),
416 else_branch.join(",")
417 )
418 }
419 Node::Loop {
420 condition,
421 body,
422 loop_vars,
423 } => {
424 format!(
425 "loop:{}:{}:{}",
426 condition,
427 body.join(","),
428 loop_vars.join(",")
429 )
430 }
431 Node::Merge { inputs } => {
432 let inputs_str = inputs.join(",");
433 format!("merge:{inputs_str}")
434 }
435 Node::GetAttr { target, attr } => format!("getattr:{target}:{attr}"),
436 }
437 }
438
439 pub fn remove_redundant_nodes(graph: &FxGraph) -> TorshResult<FxGraph> {
441 let mut compressed_graph = FxGraph::new();
442 let mut node_mapping: HashMap<NodeIndex, Option<NodeIndex>> = HashMap::new();
443
444 for (old_idx, node) in graph.nodes() {
446 if Self::is_redundant_node(node) {
447 node_mapping.insert(old_idx, None); } else {
449 let new_idx = compressed_graph.graph.add_node(node.clone());
450 node_mapping.insert(old_idx, Some(new_idx));
451 }
452 }
453
454 for edge_ref in graph.graph.edge_references() {
456 use petgraph::visit::EdgeRef;
457 let old_source = edge_ref.source();
458 let old_target = edge_ref.target();
459
460 let new_source = Self::find_actual_node(old_source, &node_mapping, graph);
462 let new_target = Self::find_actual_node(old_target, &node_mapping, graph);
463
464 if let (Some(source), Some(target)) = (new_source, new_target) {
465 if source != target {
466 compressed_graph
468 .graph
469 .add_edge(source, target, edge_ref.weight().clone());
470 }
471 }
472 }
473
474 compressed_graph.inputs = graph
476 .inputs()
477 .iter()
478 .filter_map(|&idx| Self::find_actual_node(idx, &node_mapping, graph))
479 .collect();
480 compressed_graph.outputs = graph
481 .outputs()
482 .iter()
483 .filter_map(|&idx| Self::find_actual_node(idx, &node_mapping, graph))
484 .collect();
485
486 Ok(compressed_graph)
487 }
488
489 fn is_redundant_node(node: &Node) -> bool {
491 match node {
492 Node::Call(op, _) => op == "identity" || op == "noop",
493 _ => false,
494 }
495 }
496
497 fn find_actual_node(
499 start_idx: NodeIndex,
500 node_mapping: &HashMap<NodeIndex, Option<NodeIndex>>,
501 _graph: &FxGraph,
502 ) -> Option<NodeIndex> {
503 node_mapping.get(&start_idx).and_then(|&idx| idx)
504 }
505}
506
/// Records per-operation execution times and derives bottleneck reports.
#[derive(Debug)]
pub struct PerformanceProfiler {
    // All recorded durations, grouped by operation name.
    operation_times: RwLock<HashMap<String, Vec<Duration>>>,
    // Bottlenecks computed by the most recent analysis, sorted by impact.
    bottlenecks: RwLock<Vec<PerformanceBottleneck>>,
    // When false, `record_operation` is a no-op.
    profiling_enabled: bool,
}
514
/// One operation flagged as a performance bottleneck.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceBottleneck {
    /// Operation name as recorded.
    pub operation: String,
    /// Mean duration across all recordings of this operation.
    pub average_time: Duration,
    /// Number of times the operation was recorded.
    pub frequency: u64,
    /// average_time (seconds) * frequency; used for ranking.
    pub impact_score: f64,
    /// Human-readable optimization suggestions for this operation.
    pub recommendations: Vec<String>,
}
523
/// Summary produced after profiling a graph execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReport {
    /// Total number of recorded operation invocations.
    pub total_operations: u64,
    /// Wall-clock time of the profiled execution.
    pub total_time: Duration,
    /// total_time divided by total_operations.
    pub average_operation_time: Duration,
    /// Bottlenecks sorted by descending impact score.
    pub bottlenecks: Vec<PerformanceBottleneck>,
    /// Graph-wide optimization suggestions.
    pub optimization_suggestions: Vec<String>,
}
532
533impl PerformanceProfiler {
534 pub fn new() -> Self {
536 Self {
537 operation_times: RwLock::new(HashMap::new()),
538 bottlenecks: RwLock::new(Vec::new()),
539 profiling_enabled: true,
540 }
541 }
542
543 pub fn set_profiling_enabled(&mut self, enabled: bool) {
545 self.profiling_enabled = enabled;
546 }
547
548 pub fn record_operation(&self, operation: &str, duration: Duration) {
550 if !self.profiling_enabled {
551 return;
552 }
553
554 let mut times = self
555 .operation_times
556 .write()
557 .expect("lock should not be poisoned");
558 times
559 .entry(operation.to_string())
560 .or_insert_with(Vec::new)
561 .push(duration);
562 }
563
564 pub fn profile_graph_execution<F>(
566 &self,
567 graph: &FxGraph,
568 executor: F,
569 ) -> TorshResult<PerformanceReport>
570 where
571 F: FnOnce(&FxGraph) -> TorshResult<()>,
572 {
573 let start_time = Instant::now();
574
575 executor(graph)?;
577
578 let total_time = start_time.elapsed();
579
580 self.analyze_bottlenecks();
582
583 let report = self.generate_report(total_time);
585 Ok(report)
586 }
587
588 fn analyze_bottlenecks(&self) {
590 let times = self
591 .operation_times
592 .read()
593 .expect("lock should not be poisoned");
594 let mut bottlenecks = Vec::new();
595
596 for (operation, durations) in times.iter() {
597 if durations.is_empty() {
598 continue;
599 }
600
601 let total_time: Duration = durations.iter().sum();
602 let average_time = total_time / durations.len() as u32;
603 let frequency = durations.len() as u64;
604
605 let impact_score = average_time.as_secs_f64() * frequency as f64;
607
608 let recommendations = self.generate_recommendations(operation, average_time, frequency);
610
611 if impact_score > 0.1 {
612 bottlenecks.push(PerformanceBottleneck {
614 operation: operation.clone(),
615 average_time,
616 frequency,
617 impact_score,
618 recommendations,
619 });
620 }
621 }
622
623 bottlenecks.sort_by(|a, b| {
625 b.impact_score
626 .partial_cmp(&a.impact_score)
627 .expect("impact_score should be comparable")
628 });
629
630 *self
631 .bottlenecks
632 .write()
633 .expect("lock should not be poisoned") = bottlenecks;
634 }
635
636 fn generate_recommendations(
638 &self,
639 operation: &str,
640 avg_time: Duration,
641 frequency: u64,
642 ) -> Vec<String> {
643 let mut recommendations = Vec::new();
644
645 if avg_time.as_millis() > 100 {
646 recommendations.push(format!(
647 "Consider optimizing '{}' operation - high execution time",
648 operation
649 ));
650 }
651
652 if frequency > 1000 {
653 recommendations.push(format!(
654 "Operation '{}' is called frequently - consider caching",
655 operation
656 ));
657 }
658
659 if operation.contains("conv") && avg_time.as_millis() > 50 {
660 recommendations.push(
661 "Consider using optimized convolution algorithms or GPU acceleration".to_string(),
662 );
663 }
664
665 if operation.contains("matmul") && avg_time.as_millis() > 20 {
666 recommendations.push(
667 "Consider using BLAS libraries or tensor cores for matrix multiplication"
668 .to_string(),
669 );
670 }
671
672 if recommendations.is_empty() {
673 recommendations.push("Performance seems adequate for this operation".to_string());
674 }
675
676 recommendations
677 }
678
679 fn generate_report(&self, total_time: Duration) -> PerformanceReport {
681 let times = self
682 .operation_times
683 .read()
684 .expect("lock should not be poisoned");
685 let bottlenecks = self
686 .bottlenecks
687 .read()
688 .expect("lock should not be poisoned")
689 .clone();
690
691 let total_operations: u64 = times.values().map(|v| v.len() as u64).sum();
692 let average_operation_time = if total_operations > 0 {
693 total_time / total_operations as u32
694 } else {
695 Duration::from_millis(0)
696 };
697
698 let optimization_suggestions = self.generate_global_optimizations(&bottlenecks);
699
700 PerformanceReport {
701 total_operations,
702 total_time,
703 average_operation_time,
704 bottlenecks,
705 optimization_suggestions,
706 }
707 }
708
709 fn generate_global_optimizations(&self, bottlenecks: &[PerformanceBottleneck]) -> Vec<String> {
711 let mut suggestions = Vec::new();
712
713 if bottlenecks.len() > 5 {
714 suggestions.push(
715 "Consider using graph optimization passes to reduce operation count".to_string(),
716 );
717 }
718
719 if bottlenecks
720 .iter()
721 .any(|b| b.operation.contains("copy") || b.operation.contains("transpose"))
722 {
723 suggestions
724 .push("Consider memory layout optimizations to reduce data movement".to_string());
725 }
726
727 if bottlenecks.iter().any(|b| b.frequency > 100) {
728 suggestions
729 .push("Enable operation caching for frequently used computations".to_string());
730 }
731
732 suggestions
733 .push("Consider using parallel execution for independent operations".to_string());
734 suggestions
735 .push("Enable compiler optimizations and use release build for production".to_string());
736
737 suggestions
738 }
739
740 pub fn clear(&self) {
742 self.operation_times
743 .write()
744 .expect("lock should not be poisoned")
745 .clear();
746 self.bottlenecks
747 .write()
748 .expect("lock should not be poisoned")
749 .clear();
750 }
751
752 pub fn bottlenecks(&self) -> Vec<PerformanceBottleneck> {
754 self.bottlenecks
755 .read()
756 .expect("lock should not be poisoned")
757 .clone()
758 }
759}
760
761impl Default for PerformanceProfiler {
762 fn default() -> Self {
763 Self::new()
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, FxGraph, Node};
    use std::sync::Arc;

    /// Build input -> relu -> output and check the parallel topological
    /// traversal visits all three nodes.
    #[test]
    fn test_parallel_traversal() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        let x_edge = Edge {
            name: "x".to_string(),
        };
        let relu_edge = Edge {
            name: "relu_out".to_string(),
        };
        graph.graph.add_edge(input, relu, x_edge);
        graph.graph.add_edge(relu, output, relu_edge);
        graph.inputs.push(input);
        graph.outputs.push(output);

        let traversal = ParallelTraversal::new(Arc::new(graph));
        let seen = Arc::new(Mutex::new(Vec::new()));

        let result = traversal.parallel_topological_traverse(|idx, _node| {
            seen.lock().expect("lock should not be poisoned").push(idx);
        });

        assert!(result.is_ok());
        let visited_count = seen.lock().expect("lock should not be poisoned").len();
        assert_eq!(visited_count, 3);
    }

    /// Miss, then hit twice; access counts and statistics must track that.
    #[test]
    fn test_graph_cache() {
        let cache = GraphCache::new(100);

        assert!(cache.get_operation("test_op").is_none());

        cache.cache_operation(
            "test_op".to_string(),
            "result".to_string(),
            Duration::from_millis(100),
        );

        let first_hit = cache.get_operation("test_op").unwrap();
        assert_eq!(first_hit.result, "result");
        assert_eq!(first_hit.access_count, 2);

        let second_hit = cache.get_operation("test_op").unwrap();
        assert_eq!(second_hit.access_count, 3);

        let stats = cache.statistics();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
    }

    /// Two identical Input("x") nodes must be merged by deduplication.
    #[test]
    fn test_graph_compression() {
        let mut graph = FxGraph::new();
        let input1 = graph.graph.add_node(Node::Input("x".to_string()));
        let input2 = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        for (src, dst, name) in [
            (input1, relu, "x1"),
            (input2, relu, "x2"),
            (relu, output, "relu_out"),
        ] {
            graph.graph.add_edge(
                src,
                dst,
                Edge {
                    name: name.to_string(),
                },
            );
        }

        let compressed = GraphCompression::deduplicate_operations(&graph).unwrap();

        assert!(compressed.node_count() < graph.node_count());
    }

    /// Recorded operations must show up in the report with bottlenecks
    /// and suggestions populated.
    #[test]
    fn test_performance_profiler() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation("conv2d", Duration::from_millis(100));
        profiler.record_operation("conv2d", Duration::from_millis(120));
        profiler.record_operation("relu", Duration::from_millis(10));

        let graph = FxGraph::new();

        let report = profiler
            .profile_graph_execution(&graph, |_| Ok(()))
            .unwrap();

        assert_eq!(report.total_operations, 3);
        assert!(!report.bottlenecks.is_empty());
        assert!(!report.optimization_suggestions.is_empty());
    }

    /// With capacity 2, inserting a third entry must evict the LRU one (op1).
    #[test]
    fn test_cache_statistics() {
        let cache = GraphCache::new(2);

        cache.cache_operation(
            "op1".to_string(),
            "result1".to_string(),
            Duration::from_millis(50),
        );
        // Sleep so each entry gets a strictly later last_accessed timestamp.
        std::thread::sleep(Duration::from_millis(1));

        cache.cache_operation(
            "op2".to_string(),
            "result2".to_string(),
            Duration::from_millis(75),
        );
        std::thread::sleep(Duration::from_millis(1));

        cache.cache_operation(
            "op3".to_string(),
            "result3".to_string(),
            Duration::from_millis(100),
        );

        let stats = cache.statistics();
        assert_eq!(stats.evictions, 1);

        assert!(cache.get_operation("op1").is_none());
        assert!(cache.get_operation("op2").is_some() || cache.get_operation("op3").is_some());

        let op2_exists = cache.get_operation("op2").is_some();
        let op3_exists = cache.get_operation("op3").is_some();
        assert_eq!((op2_exists as usize) + (op3_exists as usize), 2);
    }
}
942}