Skip to main content

torsh_fx/
performance.rs

1//! Performance optimization and analysis utilities for FX graphs
2//!
3//! This module provides advanced performance optimization features including:
4//! - Parallel graph traversal for large graphs
5//! - Graph caching and memoization for repeated operations
6//! - Graph compression techniques for reduced memory usage
7//! - Automatic performance profiling and bottleneck detection
8
9use crate::{FxGraph, Node, TorshResult};
10use petgraph::graph::NodeIndex;
11use petgraph::visit::EdgeRef;
12// SCIRS2 POLICY COMPLIANCE: Use scirs2_core::parallel_ops instead of direct rayon
13use scirs2_core::parallel_ops::*;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex, RwLock};
17use std::time::{Duration, Instant};
18
/// Parallel graph traversal utilities for large graphs
pub struct ParallelTraversal {
    // Shared, read-only view of the graph being traversed.
    graph: Arc<FxGraph>,
    // Requested worker count; `None` means use the pool's default
    // (typically the number of CPU cores).
    // NOTE(review): recorded by `with_thread_pool_size` but never applied to
    // the parallel iterators in this file — confirm intended wiring.
    thread_pool_size: Option<usize>,
}
24
25impl ParallelTraversal {
26    /// Create a new parallel traversal instance
27    pub fn new(graph: Arc<FxGraph>) -> Self {
28        Self {
29            graph,
30            thread_pool_size: None,
31        }
32    }
33
34    /// Set custom thread pool size (default: number of CPU cores)
35    pub fn with_thread_pool_size(mut self, size: usize) -> Self {
36        self.thread_pool_size = Some(size);
37        self
38    }
39
40    /// Perform parallel topological traversal of the graph
41    pub fn parallel_topological_traverse<F>(&self, visitor: F) -> TorshResult<()>
42    where
43        F: Fn(NodeIndex, &Node) + Send + Sync,
44    {
45        // Build dependency graph for topological ordering
46        let dependencies = self.build_dependency_map();
47        let visited = HashSet::new();
48        let mut ready_nodes = Vec::new();
49
50        // Find nodes with no dependencies
51        for (idx, _) in self.graph.nodes() {
52            if dependencies.get(&idx).map_or(true, |deps| deps.is_empty()) {
53                ready_nodes.push(idx);
54            }
55        }
56
57        let visited = Arc::new(Mutex::new(visited));
58        let dependencies = Arc::new(dependencies);
59
60        while !ready_nodes.is_empty() {
61            // Process ready nodes in parallel
62            ready_nodes.par_iter().for_each(|&idx| {
63                if let Some(node) = self.graph.get_node(idx) {
64                    visitor(idx, node);
65                    visited
66                        .lock()
67                        .expect("lock should not be poisoned")
68                        .insert(idx);
69                }
70            });
71
72            // Find newly ready nodes - optimized to avoid cloning the entire visited set
73            ready_nodes.clear();
74
75            for (idx, deps) in dependencies.iter() {
76                let visited_guard = visited.lock().expect("lock should not be poisoned");
77                if !visited_guard.contains(idx) {
78                    if deps.iter().all(|dep| visited_guard.contains(dep)) {
79                        ready_nodes.push(*idx);
80                    }
81                }
82                // Release the lock early by dropping the guard
83                drop(visited_guard);
84            }
85        }
86
87        Ok(())
88    }
89
90    /// Perform parallel depth-first search with work stealing
91    pub fn parallel_dfs<F>(&self, start_nodes: Vec<NodeIndex>, visitor: F) -> TorshResult<()>
92    where
93        F: Fn(NodeIndex, &Node) + Send + Sync,
94    {
95        let visited = Arc::new(Mutex::new(HashSet::new()));
96        let visitor = Arc::new(visitor);
97
98        start_nodes.into_par_iter().for_each(|start| {
99            self.dfs_worker(start, visited.clone(), visitor.clone());
100        });
101
102        Ok(())
103    }
104
105    /// Build a dependency map for topological traversal
106    fn build_dependency_map(&self) -> HashMap<NodeIndex, Vec<NodeIndex>> {
107        let mut dependencies = HashMap::new();
108
109        for (idx, _) in self.graph.nodes() {
110            dependencies.insert(idx, Vec::new());
111        }
112
113        // Build reverse dependency map
114        for edge_ref in self.graph.graph.edge_references() {
115            use petgraph::visit::EdgeRef;
116            let target = edge_ref.target();
117            let source = edge_ref.source();
118            dependencies
119                .get_mut(&target)
120                .expect("target node should exist in dependencies map")
121                .push(source);
122        }
123
124        dependencies
125    }
126
127    /// DFS worker for parallel traversal
128    fn dfs_worker<F>(
129        &self,
130        node: NodeIndex,
131        visited: Arc<Mutex<HashSet<NodeIndex>>>,
132        visitor: Arc<F>,
133    ) where
134        F: Fn(NodeIndex, &Node) + Send + Sync,
135    {
136        let mut stack = vec![node];
137
138        while let Some(current) = stack.pop() {
139            let already_visited = {
140                let mut v = visited.lock().expect("lock should not be poisoned");
141                if v.contains(&current) {
142                    true
143                } else {
144                    v.insert(current);
145                    false
146                }
147            };
148
149            if already_visited {
150                continue;
151            }
152
153            if let Some(node_data) = self.graph.get_node(current) {
154                visitor(current, node_data);
155            }
156
157            // Add neighbors to stack
158            for edge in self.graph.graph.edges(current) {
159                stack.push(edge.target());
160            }
161        }
162    }
163}
164
/// Graph caching and memoization system
///
/// Holds two bounded caches (operation results and subgraphs) plus shared
/// hit/miss/eviction statistics.
#[derive(Debug)]
pub struct GraphCache {
    // Memoized operation results, keyed by a caller-chosen string.
    operation_cache: RwLock<HashMap<String, CachedResult>>,
    // Cached subgraphs, handed out as cheap `Arc` clones.
    subgraph_cache: RwLock<HashMap<String, Arc<FxGraph>>>,
    // Counters behind a Mutex so `&self` methods can update them.
    cache_stats: Arc<Mutex<CacheStatistics>>,
    // Maximum number of entries allowed in each cache before eviction.
    max_cache_size: usize,
}
173
/// A memoized operation result together with its usage metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// The cached (stringified) result of the operation.
    pub result: String,
    /// How long the original computation took — the time saved per hit.
    pub computation_time: Duration,
    /// Number of times this entry has been accessed (starts at 1 on insert).
    pub access_count: u64,
    /// Timestamp of the most recent access; used for LRU eviction.
    pub last_accessed: std::time::SystemTime,
}
181
/// Aggregate hit/miss/eviction counters for a `GraphCache`.
#[derive(Debug, Default, Clone)]
pub struct CacheStatistics {
    /// Number of successful cache lookups.
    pub hits: u64,
    /// Number of failed cache lookups.
    pub misses: u64,
    /// Number of entries removed to make room for new ones.
    pub evictions: u64,
    /// Cumulative computation time avoided by serving results from cache.
    pub total_computation_time_saved: Duration,
}
189
190impl GraphCache {
191    /// Create a new graph cache with specified maximum size
192    pub fn new(max_cache_size: usize) -> Self {
193        Self {
194            operation_cache: RwLock::new(HashMap::new()),
195            subgraph_cache: RwLock::new(HashMap::new()),
196            cache_stats: Arc::new(Mutex::new(CacheStatistics::default())),
197            max_cache_size,
198        }
199    }
200
201    /// Get cached operation result
202    pub fn get_operation(&self, key: &str) -> Option<CachedResult> {
203        let mut cache = self
204            .operation_cache
205            .write()
206            .expect("lock should not be poisoned");
207        if let Some(result) = cache.get_mut(key) {
208            result.access_count += 1;
209            result.last_accessed = std::time::SystemTime::now();
210            self.cache_stats
211                .lock()
212                .expect("lock should not be poisoned")
213                .hits += 1;
214            Some(result.clone())
215        } else {
216            self.cache_stats
217                .lock()
218                .expect("lock should not be poisoned")
219                .misses += 1;
220            None
221        }
222    }
223
224    /// Cache operation result
225    pub fn cache_operation(&self, key: String, result: String, computation_time: Duration) {
226        let mut cache = self
227            .operation_cache
228            .write()
229            .expect("lock should not be poisoned");
230
231        // Evict oldest entries if cache is full
232        if cache.len() >= self.max_cache_size {
233            self.evict_lru_operation(&mut cache);
234        }
235
236        let cached_result = CachedResult {
237            result,
238            computation_time,
239            access_count: 1,
240            last_accessed: std::time::SystemTime::now(),
241        };
242
243        cache.insert(key, cached_result);
244    }
245
246    /// Get cached subgraph
247    pub fn get_subgraph(&self, key: &str) -> Option<Arc<FxGraph>> {
248        let cache = self
249            .subgraph_cache
250            .read()
251            .expect("lock should not be poisoned");
252        if let Some(graph) = cache.get(key) {
253            self.cache_stats
254                .lock()
255                .expect("lock should not be poisoned")
256                .hits += 1;
257            Some(graph.clone())
258        } else {
259            self.cache_stats
260                .lock()
261                .expect("lock should not be poisoned")
262                .misses += 1;
263            None
264        }
265    }
266
267    /// Cache subgraph
268    pub fn cache_subgraph(&self, key: String, graph: Arc<FxGraph>) {
269        let mut cache = self
270            .subgraph_cache
271            .write()
272            .expect("lock should not be poisoned");
273
274        // Evict oldest entries if cache is full
275        if cache.len() >= self.max_cache_size {
276            self.evict_lru_subgraph(&mut cache);
277        }
278
279        cache.insert(key, graph);
280    }
281
282    /// Get cache statistics
283    pub fn statistics(&self) -> CacheStatistics {
284        self.cache_stats
285            .lock()
286            .expect("lock should not be poisoned")
287            .clone()
288    }
289
290    /// Clear all caches
291    pub fn clear(&self) {
292        self.operation_cache
293            .write()
294            .expect("lock should not be poisoned")
295            .clear();
296        self.subgraph_cache
297            .write()
298            .expect("lock should not be poisoned")
299            .clear();
300        *self
301            .cache_stats
302            .lock()
303            .expect("lock should not be poisoned") = CacheStatistics::default();
304    }
305
306    /// Evict least recently used operation
307    fn evict_lru_operation(&self, cache: &mut HashMap<String, CachedResult>) {
308        if let Some(lru_key) = cache
309            .iter()
310            .min_by_key(|(_, result)| result.last_accessed)
311            .map(|(key, _)| key.clone())
312        {
313            cache.remove(&lru_key);
314            self.cache_stats
315                .lock()
316                .expect("lock should not be poisoned")
317                .evictions += 1;
318        }
319    }
320
321    /// Evict least recently used subgraph (simplified LRU)
322    fn evict_lru_subgraph(&self, cache: &mut HashMap<String, Arc<FxGraph>>) {
323        if let Some(key) = cache.keys().next().cloned() {
324            cache.remove(&key);
325            self.cache_stats
326                .lock()
327                .expect("lock should not be poisoned")
328                .evictions += 1;
329        }
330    }
331}
332
/// Graph compression utilities
///
/// Stateless namespace for graph-shrinking transforms: operation
/// deduplication and redundant-node removal.
pub struct GraphCompression;
335
336impl GraphCompression {
337    /// Compress graph using operation deduplication
338    pub fn deduplicate_operations(graph: &FxGraph) -> TorshResult<FxGraph> {
339        let mut compressed_graph = FxGraph::new();
340        let mut operation_map: HashMap<String, NodeIndex> = HashMap::new();
341        let mut node_mapping: HashMap<NodeIndex, NodeIndex> = HashMap::new();
342
343        // First pass: deduplicate operations
344        for (old_idx, node) in graph.nodes() {
345            let operation_key = Self::operation_key(node);
346
347            if let Some(&existing_idx) = operation_map.get(&operation_key) {
348                // Reuse existing operation
349                node_mapping.insert(old_idx, existing_idx);
350            } else {
351                // Create new operation
352                let new_idx = compressed_graph.graph.add_node(node.clone());
353                operation_map.insert(operation_key, new_idx);
354                node_mapping.insert(old_idx, new_idx);
355            }
356        }
357
358        // Second pass: rebuild edges
359        for edge_ref in graph.graph.edge_references() {
360            use petgraph::visit::EdgeRef;
361            let old_source = edge_ref.source();
362            let old_target = edge_ref.target();
363
364            if let (Some(&new_source), Some(&new_target)) =
365                (node_mapping.get(&old_source), node_mapping.get(&old_target))
366            {
367                // Avoid duplicate edges
368                if !compressed_graph
369                    .graph
370                    .edges_connecting(new_source, new_target)
371                    .next()
372                    .is_some()
373                {
374                    compressed_graph.graph.add_edge(
375                        new_source,
376                        new_target,
377                        edge_ref.weight().clone(),
378                    );
379                }
380            }
381        }
382
383        // Update inputs and outputs
384        compressed_graph.inputs = graph
385            .inputs()
386            .iter()
387            .filter_map(|&idx| node_mapping.get(&idx).copied())
388            .collect();
389        compressed_graph.outputs = graph
390            .outputs()
391            .iter()
392            .filter_map(|&idx| node_mapping.get(&idx).copied())
393            .collect();
394
395        Ok(compressed_graph)
396    }
397
398    /// Create operation key for deduplication
399    fn operation_key(node: &Node) -> String {
400        match node {
401            Node::Input(name) => format!("input:{name}"),
402            Node::Call(op, args) => {
403                let args_str = args.join(",");
404                format!("call:{op}:{args_str}")
405            }
406            Node::Output => "output".into(),
407            Node::Conditional {
408                condition,
409                then_branch,
410                else_branch,
411            } => {
412                format!(
413                    "conditional:{}:{}:{}",
414                    condition,
415                    then_branch.join(","),
416                    else_branch.join(",")
417                )
418            }
419            Node::Loop {
420                condition,
421                body,
422                loop_vars,
423            } => {
424                format!(
425                    "loop:{}:{}:{}",
426                    condition,
427                    body.join(","),
428                    loop_vars.join(",")
429                )
430            }
431            Node::Merge { inputs } => {
432                let inputs_str = inputs.join(",");
433                format!("merge:{inputs_str}")
434            }
435            Node::GetAttr { target, attr } => format!("getattr:{target}:{attr}"),
436        }
437    }
438
439    /// Compress graph by removing redundant nodes
440    pub fn remove_redundant_nodes(graph: &FxGraph) -> TorshResult<FxGraph> {
441        let mut compressed_graph = FxGraph::new();
442        let mut node_mapping: HashMap<NodeIndex, Option<NodeIndex>> = HashMap::new();
443
444        // Identify redundant nodes (e.g., identity operations)
445        for (old_idx, node) in graph.nodes() {
446            if Self::is_redundant_node(node) {
447                node_mapping.insert(old_idx, None); // Mark for removal
448            } else {
449                let new_idx = compressed_graph.graph.add_node(node.clone());
450                node_mapping.insert(old_idx, Some(new_idx));
451            }
452        }
453
454        // Rebuild edges, skipping redundant nodes
455        for edge_ref in graph.graph.edge_references() {
456            use petgraph::visit::EdgeRef;
457            let old_source = edge_ref.source();
458            let old_target = edge_ref.target();
459
460            // Find actual source and target (skipping redundant nodes)
461            let new_source = Self::find_actual_node(old_source, &node_mapping, graph);
462            let new_target = Self::find_actual_node(old_target, &node_mapping, graph);
463
464            if let (Some(source), Some(target)) = (new_source, new_target) {
465                if source != target {
466                    // Avoid self-loops
467                    compressed_graph
468                        .graph
469                        .add_edge(source, target, edge_ref.weight().clone());
470                }
471            }
472        }
473
474        // Update inputs and outputs
475        compressed_graph.inputs = graph
476            .inputs()
477            .iter()
478            .filter_map(|&idx| Self::find_actual_node(idx, &node_mapping, graph))
479            .collect();
480        compressed_graph.outputs = graph
481            .outputs()
482            .iter()
483            .filter_map(|&idx| Self::find_actual_node(idx, &node_mapping, graph))
484            .collect();
485
486        Ok(compressed_graph)
487    }
488
489    /// Check if a node is redundant (e.g., identity operation)
490    fn is_redundant_node(node: &Node) -> bool {
491        match node {
492            Node::Call(op, _) => op == "identity" || op == "noop",
493            _ => false,
494        }
495    }
496
497    /// Find the actual node after skipping redundant nodes
498    fn find_actual_node(
499        start_idx: NodeIndex,
500        node_mapping: &HashMap<NodeIndex, Option<NodeIndex>>,
501        _graph: &FxGraph,
502    ) -> Option<NodeIndex> {
503        node_mapping.get(&start_idx).and_then(|&idx| idx)
504    }
505}
506
/// Automatic performance profiling and bottleneck detection
#[derive(Debug)]
pub struct PerformanceProfiler {
    // Recorded durations per operation name.
    operation_times: RwLock<HashMap<String, Vec<Duration>>>,
    // Bottlenecks computed by analysis, sorted by impact score (descending).
    bottlenecks: RwLock<Vec<PerformanceBottleneck>>,
    // When false, `record_operation` is a no-op.
    profiling_enabled: bool,
}
514
/// A single detected performance hotspot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceBottleneck {
    /// Name of the profiled operation.
    pub operation: String,
    /// Mean execution time across all recorded runs.
    pub average_time: Duration,
    /// Number of recorded runs of this operation.
    pub frequency: u64,
    /// Ranking score: average time in seconds multiplied by frequency.
    pub impact_score: f64,
    /// Human-readable optimization hints for this operation.
    pub recommendations: Vec<String>,
}
523
/// Summary produced by profiling a graph execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReport {
    /// Total number of recorded operation executions.
    pub total_operations: u64,
    /// Wall-clock time of the profiled execution.
    pub total_time: Duration,
    /// Wall-clock time divided by operation count (zero when none recorded).
    pub average_operation_time: Duration,
    /// Detected hotspots, highest impact first.
    pub bottlenecks: Vec<PerformanceBottleneck>,
    /// Graph-level optimization suggestions.
    pub optimization_suggestions: Vec<String>,
}
532
533impl PerformanceProfiler {
534    /// Create a new performance profiler
535    pub fn new() -> Self {
536        Self {
537            operation_times: RwLock::new(HashMap::new()),
538            bottlenecks: RwLock::new(Vec::new()),
539            profiling_enabled: true,
540        }
541    }
542
543    /// Enable or disable profiling
544    pub fn set_profiling_enabled(&mut self, enabled: bool) {
545        self.profiling_enabled = enabled;
546    }
547
548    /// Record operation execution time
549    pub fn record_operation(&self, operation: &str, duration: Duration) {
550        if !self.profiling_enabled {
551            return;
552        }
553
554        let mut times = self
555            .operation_times
556            .write()
557            .expect("lock should not be poisoned");
558        times
559            .entry(operation.to_string())
560            .or_insert_with(Vec::new)
561            .push(duration);
562    }
563
564    /// Profile graph execution
565    pub fn profile_graph_execution<F>(
566        &self,
567        graph: &FxGraph,
568        executor: F,
569    ) -> TorshResult<PerformanceReport>
570    where
571        F: FnOnce(&FxGraph) -> TorshResult<()>,
572    {
573        let start_time = Instant::now();
574
575        // Execute the graph
576        executor(graph)?;
577
578        let total_time = start_time.elapsed();
579
580        // Analyze performance data
581        self.analyze_bottlenecks();
582
583        // Generate report
584        let report = self.generate_report(total_time);
585        Ok(report)
586    }
587
588    /// Detect and analyze performance bottlenecks
589    fn analyze_bottlenecks(&self) {
590        let times = self
591            .operation_times
592            .read()
593            .expect("lock should not be poisoned");
594        let mut bottlenecks = Vec::new();
595
596        for (operation, durations) in times.iter() {
597            if durations.is_empty() {
598                continue;
599            }
600
601            let total_time: Duration = durations.iter().sum();
602            let average_time = total_time / durations.len() as u32;
603            let frequency = durations.len() as u64;
604
605            // Calculate impact score (average time * frequency)
606            let impact_score = average_time.as_secs_f64() * frequency as f64;
607
608            // Generate recommendations
609            let recommendations = self.generate_recommendations(operation, average_time, frequency);
610
611            if impact_score > 0.1 {
612                // Threshold for considering as bottleneck
613                bottlenecks.push(PerformanceBottleneck {
614                    operation: operation.clone(),
615                    average_time,
616                    frequency,
617                    impact_score,
618                    recommendations,
619                });
620            }
621        }
622
623        // Sort by impact score
624        bottlenecks.sort_by(|a, b| {
625            b.impact_score
626                .partial_cmp(&a.impact_score)
627                .expect("impact_score should be comparable")
628        });
629
630        *self
631            .bottlenecks
632            .write()
633            .expect("lock should not be poisoned") = bottlenecks;
634    }
635
636    /// Generate optimization recommendations
637    fn generate_recommendations(
638        &self,
639        operation: &str,
640        avg_time: Duration,
641        frequency: u64,
642    ) -> Vec<String> {
643        let mut recommendations = Vec::new();
644
645        if avg_time.as_millis() > 100 {
646            recommendations.push(format!(
647                "Consider optimizing '{}' operation - high execution time",
648                operation
649            ));
650        }
651
652        if frequency > 1000 {
653            recommendations.push(format!(
654                "Operation '{}' is called frequently - consider caching",
655                operation
656            ));
657        }
658
659        if operation.contains("conv") && avg_time.as_millis() > 50 {
660            recommendations.push(
661                "Consider using optimized convolution algorithms or GPU acceleration".to_string(),
662            );
663        }
664
665        if operation.contains("matmul") && avg_time.as_millis() > 20 {
666            recommendations.push(
667                "Consider using BLAS libraries or tensor cores for matrix multiplication"
668                    .to_string(),
669            );
670        }
671
672        if recommendations.is_empty() {
673            recommendations.push("Performance seems adequate for this operation".to_string());
674        }
675
676        recommendations
677    }
678
679    /// Generate comprehensive performance report
680    fn generate_report(&self, total_time: Duration) -> PerformanceReport {
681        let times = self
682            .operation_times
683            .read()
684            .expect("lock should not be poisoned");
685        let bottlenecks = self
686            .bottlenecks
687            .read()
688            .expect("lock should not be poisoned")
689            .clone();
690
691        let total_operations: u64 = times.values().map(|v| v.len() as u64).sum();
692        let average_operation_time = if total_operations > 0 {
693            total_time / total_operations as u32
694        } else {
695            Duration::from_millis(0)
696        };
697
698        let optimization_suggestions = self.generate_global_optimizations(&bottlenecks);
699
700        PerformanceReport {
701            total_operations,
702            total_time,
703            average_operation_time,
704            bottlenecks,
705            optimization_suggestions,
706        }
707    }
708
709    /// Generate global optimization suggestions
710    fn generate_global_optimizations(&self, bottlenecks: &[PerformanceBottleneck]) -> Vec<String> {
711        let mut suggestions = Vec::new();
712
713        if bottlenecks.len() > 5 {
714            suggestions.push(
715                "Consider using graph optimization passes to reduce operation count".to_string(),
716            );
717        }
718
719        if bottlenecks
720            .iter()
721            .any(|b| b.operation.contains("copy") || b.operation.contains("transpose"))
722        {
723            suggestions
724                .push("Consider memory layout optimizations to reduce data movement".to_string());
725        }
726
727        if bottlenecks.iter().any(|b| b.frequency > 100) {
728            suggestions
729                .push("Enable operation caching for frequently used computations".to_string());
730        }
731
732        suggestions
733            .push("Consider using parallel execution for independent operations".to_string());
734        suggestions
735            .push("Enable compiler optimizations and use release build for production".to_string());
736
737        suggestions
738    }
739
740    /// Clear all profiling data
741    pub fn clear(&self) {
742        self.operation_times
743            .write()
744            .expect("lock should not be poisoned")
745            .clear();
746        self.bottlenecks
747            .write()
748            .expect("lock should not be poisoned")
749            .clear();
750    }
751
752    /// Get current bottlenecks
753    pub fn bottlenecks(&self) -> Vec<PerformanceBottleneck> {
754        self.bottlenecks
755            .read()
756            .expect("lock should not be poisoned")
757            .clone()
758    }
759}
760
761impl Default for PerformanceProfiler {
762    fn default() -> Self {
763        Self::new()
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, FxGraph, Node};
    use std::sync::Arc;

    // Builds a 3-node chain (input -> relu -> output) and checks that the
    // parallel topological traversal visits every node exactly once.
    #[test]
    fn test_parallel_traversal() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        graph.graph.add_edge(
            input,
            relu,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );
        graph.inputs.push(input);
        graph.outputs.push(output);

        let parallel_traversal = ParallelTraversal::new(Arc::new(graph));
        let visited_nodes = Vec::new();
        // Shared sink so the Send + Sync visitor closure can record visits.
        let visited_nodes = Arc::new(Mutex::new(visited_nodes));

        let result = parallel_traversal.parallel_topological_traverse(|idx, _node| {
            visited_nodes
                .lock()
                .expect("lock should not be poisoned")
                .push(idx);
        });

        assert!(result.is_ok());
        // All three nodes must have been visited.
        assert_eq!(
            visited_nodes
                .lock()
                .expect("lock should not be poisoned")
                .len(),
            3
        );
    }

    // Exercises operation caching: miss before insert, hits afterwards,
    // access_count incremented on every successful get.
    #[test]
    fn test_graph_cache() {
        let cache = GraphCache::new(100);

        // Test operation caching
        assert!(cache.get_operation("test_op").is_none());

        cache.cache_operation(
            "test_op".to_string(),
            "result".to_string(),
            Duration::from_millis(100),
        );

        // Insert sets access_count to 1; this get bumps it to 2.
        let cached = cache.get_operation("test_op").unwrap();
        assert_eq!(cached.result, "result");
        assert_eq!(cached.access_count, 2);

        // Test cache hit
        let cached_again = cache.get_operation("test_op").unwrap();
        assert_eq!(cached_again.access_count, 3);

        let stats = cache.statistics();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
    }

    // Two identical Input("x") nodes should collapse into one after
    // deduplication, shrinking the node count.
    #[test]
    fn test_graph_compression() {
        let mut graph = FxGraph::new();
        let input1 = graph.graph.add_node(Node::Input("x".to_string()));
        let input2 = graph.graph.add_node(Node::Input("x".to_string())); // Duplicate
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        graph.graph.add_edge(
            input1,
            relu,
            Edge {
                name: "x1".to_string(),
            },
        );
        graph.graph.add_edge(
            input2,
            relu,
            Edge {
                name: "x2".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );

        let compressed = GraphCompression::deduplicate_operations(&graph).unwrap();

        // Should have fewer nodes due to deduplication
        assert!(compressed.node_count() < graph.node_count());
    }

    // Records timings, runs a no-op execution, and checks the report
    // aggregates them and flags the slow conv2d as a bottleneck.
    #[test]
    fn test_performance_profiler() {
        let profiler = PerformanceProfiler::new();

        // Record some operations
        profiler.record_operation("conv2d", Duration::from_millis(100));
        profiler.record_operation("conv2d", Duration::from_millis(120));
        profiler.record_operation("relu", Duration::from_millis(10));

        // Create a simple graph for testing
        let graph = FxGraph::new();

        let report = profiler
            .profile_graph_execution(&graph, |_| Ok(()))
            .unwrap();

        assert_eq!(report.total_operations, 3);
        assert!(!report.bottlenecks.is_empty());
        assert!(!report.optimization_suggestions.is_empty());
    }

    // Fills a size-2 cache with three entries and verifies LRU eviction:
    // the oldest entry (op1) goes, the newer two remain.
    #[test]
    fn test_cache_statistics() {
        let cache = GraphCache::new(2); // Small cache for testing eviction

        cache.cache_operation(
            "op1".to_string(),
            "result1".to_string(),
            Duration::from_millis(50),
        );
        // Small delay to ensure different timestamps
        std::thread::sleep(Duration::from_millis(1));

        cache.cache_operation(
            "op2".to_string(),
            "result2".to_string(),
            Duration::from_millis(75),
        );
        std::thread::sleep(Duration::from_millis(1));

        cache.cache_operation(
            "op3".to_string(),
            "result3".to_string(),
            Duration::from_millis(100),
        ); // Should trigger eviction

        let stats = cache.statistics();
        assert_eq!(stats.evictions, 1);

        // Verify cache contents - op1 should be evicted as it was oldest
        assert!(cache.get_operation("op1").is_none());
        assert!(cache.get_operation("op2").is_some() || cache.get_operation("op3").is_some());

        // Verify we have exactly 2 items cached
        let op2_exists = cache.get_operation("op2").is_some();
        let op3_exists = cache.get_operation("op3").is_some();
        assert_eq!((op2_exists as usize) + (op3_exists as usize), 2);
    }
}