tensorlogic_ir/graph/pgo.rs

//! Profile-guided optimization for EinsumGraph.
//!
//! This module implements profile-guided optimization (PGO) that uses runtime
//! profiling data to make better optimization decisions.
//!
//! # Features
//!
//! - **Execution profiling**: Collect runtime statistics (timing, memory, cache misses)
//! - **Hotspot identification**: Find performance bottlenecks
//! - **Adaptive optimization**: Choose optimizations based on actual behavior
//! - **Feedback-directed compilation**: Use profile data to guide compilation
//!
//! # Example
//!
//! ```
//! use tensorlogic_ir::{ExecutionProfile, ProfileGuidedOptimizer, OptimizationHint};
//! use tensorlogic_ir::EinsumGraph;
//!
//! // Collect profile during execution
//! let mut profile = ExecutionProfile::new();
//! // ... execute graph and collect stats ...
//!
//! // Use profile to optimize
//! let mut graph = EinsumGraph::new();
//! let optimizer = ProfileGuidedOptimizer::new(profile);
//! let hints = optimizer.analyze(&graph);
//! ```

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;

use crate::graph::EinsumGraph;
use crate::IrError;

/// Node identifier (index into graph.nodes)
pub type NodeId = usize;
/// Tensor identifier (index into graph.tensors)
pub type TensorId = usize;

/// Runtime execution statistics for a single node.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NodeStats {
    /// Number of times executed
    pub execution_count: u64,
    /// Total time spent executing this node
    pub total_time: Duration,
    /// Minimum execution time
    pub min_time: Duration,
    /// Maximum execution time
    pub max_time: Duration,
    /// Total memory allocated (bytes)
    pub memory_allocated: u64,
    /// Peak memory used (bytes)
    pub peak_memory: u64,
    /// Cache misses (if available)
    pub cache_misses: Option<u64>,
    /// FLOPs executed
    pub flops: Option<u64>,
}

impl NodeStats {
    pub fn new() -> Self {
        Self::default()
    }

    /// Record an execution
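    ///
    /// # Example
    ///
    /// A minimal sketch of how repeated recordings update the counters
    /// (assuming `NodeStats` is re-exported at the crate root like the
    /// types in the module-level example):
    ///
    /// ```
    /// use tensorlogic_ir::NodeStats;
    /// use std::time::Duration;
    ///
    /// let mut stats = NodeStats::new();
    /// stats.record_execution(Duration::from_millis(10), 512);
    /// stats.record_execution(Duration::from_millis(30), 256);
    ///
    /// // min/max track the fastest and slowest runs; peak memory is the max seen.
    /// assert_eq!(stats.execution_count, 2);
    /// assert_eq!(stats.min_time, Duration::from_millis(10));
    /// assert_eq!(stats.max_time, Duration::from_millis(30));
    /// assert_eq!(stats.peak_memory, 512);
    /// ```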
    pub fn record_execution(&mut self, duration: Duration, memory: u64) {
        self.execution_count += 1;
        self.total_time += duration;

        if self.execution_count == 1 {
            self.min_time = duration;
            self.max_time = duration;
        } else {
            if duration < self.min_time {
                self.min_time = duration;
            }
            if duration > self.max_time {
                self.max_time = duration;
            }
        }

        self.memory_allocated += memory;
        if memory > self.peak_memory {
            self.peak_memory = memory;
        }
    }

    /// Get average execution time
    pub fn avg_time(&self) -> Duration {
        if self.execution_count > 0 {
            self.total_time / self.execution_count as u32
        } else {
            Duration::ZERO
        }
    }

    /// Get time variability, measured as the max - min spread
    pub fn time_variance(&self) -> Duration {
        self.max_time.saturating_sub(self.min_time)
    }

    /// Check if this is a hot node (frequently executed)
    pub fn is_hot(&self, threshold: u64) -> bool {
        self.execution_count >= threshold
    }

    /// Get performance score (higher is worse - indicates bottleneck)
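    ///
    /// # Example
    ///
    /// A worked instance of the weighting below, assuming `NodeStats` is
    /// re-exported at the crate root:
    ///
    /// ```
    /// use tensorlogic_ir::NodeStats;
    /// use std::time::Duration;
    ///
    /// let mut stats = NodeStats::new();
    /// stats.record_execution(Duration::from_secs(2), 1_000_000);
    ///
    /// // 2 s * 0.5 + 1 MB * 0.3 + 1 execution * 0.2 = 1.5
    /// assert!((stats.performance_score() - 1.5).abs() < 1e-9);
    /// ```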
    pub fn performance_score(&self) -> f64 {
        let time_weight = self.total_time.as_secs_f64();
        let memory_weight = self.peak_memory as f64 / 1_000_000.0; // MB
        let execution_weight = self.execution_count as f64;

        time_weight * 0.5 + memory_weight * 0.3 + execution_weight * 0.2
    }
}

/// Execution profile for the entire graph.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ExecutionProfile {
    /// Per-node statistics
    pub node_stats: HashMap<NodeId, NodeStats>,
    /// Per-tensor statistics (size, reuse count)
    pub tensor_stats: HashMap<TensorId, TensorStats>,
    /// Total graph executions
    pub total_executions: u64,
    /// Critical path (longest execution chain)
    pub critical_path: Vec<NodeId>,
}

impl ExecutionProfile {
    pub fn new() -> Self {
        Self::default()
    }

    /// Record node execution
    pub fn record_node(&mut self, node_id: NodeId, duration: Duration, memory: u64) {
        self.node_stats
            .entry(node_id)
            .or_default()
            .record_execution(duration, memory);
    }

    /// Record tensor access
    pub fn record_tensor_access(&mut self, tensor_id: TensorId, size: usize) {
        self.tensor_stats
            .entry(tensor_id)
            .or_insert_with(|| TensorStats::new(size))
            .record_access();
    }

    /// Get hot nodes (top N by performance score)
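    ///
    /// # Example
    ///
    /// A small sketch of the ranking: with the default weights, one node's
    /// execution count can outweigh another node's longer single run.
    ///
    /// ```
    /// use tensorlogic_ir::ExecutionProfile;
    /// use std::time::Duration;
    ///
    /// let mut profile = ExecutionProfile::new();
    /// for _ in 0..100 {
    ///     profile.record_node(0, Duration::from_millis(10), 100);
    /// }
    /// profile.record_node(1, Duration::from_millis(500), 10_000);
    ///
    /// // Node 0: 1 s * 0.5 + ~0 + 100 * 0.2 = ~20.5
    /// // Node 1: 0.5 s * 0.5 + ~0 + 1 * 0.2 = ~0.45
    /// let hot = profile.get_hot_nodes(1);
    /// assert_eq!(hot[0].0, 0);
    /// ```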
    pub fn get_hot_nodes(&self, n: usize) -> Vec<(NodeId, f64)> {
        let mut scores: Vec<_> = self
            .node_stats
            .iter()
            .map(|(id, stats)| (*id, stats.performance_score()))
            .collect();

        // total_cmp avoids the panic that partial_cmp().unwrap() would hit on NaN.
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores.truncate(n);
        scores
    }

    /// Get memory-intensive nodes
    pub fn get_memory_intensive_nodes(&self, threshold: u64) -> Vec<NodeId> {
        self.node_stats
            .iter()
            .filter(|(_, stats)| stats.peak_memory >= threshold)
            .map(|(id, _)| *id)
            .collect()
    }

    /// Merge another profile (for multi-run averaging)
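    ///
    /// # Example
    ///
    /// A minimal sketch of merging two single-run profiles:
    ///
    /// ```
    /// use tensorlogic_ir::ExecutionProfile;
    /// use std::time::Duration;
    ///
    /// let mut a = ExecutionProfile::new();
    /// a.record_node(0, Duration::from_millis(10), 64);
    ///
    /// let mut b = ExecutionProfile::new();
    /// b.record_node(0, Duration::from_millis(20), 128);
    ///
    /// a.merge(&b);
    /// assert_eq!(a.node_stats[&0].execution_count, 2);
    /// assert_eq!(a.node_stats[&0].min_time, Duration::from_millis(10));
    /// assert_eq!(a.node_stats[&0].max_time, Duration::from_millis(20));
    /// assert_eq!(a.node_stats[&0].peak_memory, 128);
    /// ```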
    pub fn merge(&mut self, other: &ExecutionProfile) {
        for (node_id, other_stats) in &other.node_stats {
            let stats = self.node_stats.entry(*node_id).or_default();

            stats.execution_count += other_stats.execution_count;
            stats.total_time += other_stats.total_time;
            stats.memory_allocated += other_stats.memory_allocated;

            // Take the other profile's min_time when it is smaller, or when this
            // entry was freshly created (its count now equals the other's count),
            // so a default Duration::ZERO min is not kept by mistake.
            if other_stats.min_time < stats.min_time
                || stats.execution_count == other_stats.execution_count
            {
                stats.min_time = other_stats.min_time;
            }
            if other_stats.max_time > stats.max_time {
                stats.max_time = other_stats.max_time;
            }
            if other_stats.peak_memory > stats.peak_memory {
                stats.peak_memory = other_stats.peak_memory;
            }
        }

        for (tensor_id, other_tensor_stats) in &other.tensor_stats {
            let tensor_stats = self
                .tensor_stats
                .entry(*tensor_id)
                .or_insert_with(|| TensorStats::new(other_tensor_stats.size_bytes));

            tensor_stats.access_count += other_tensor_stats.access_count;
            tensor_stats.last_access_time = tensor_stats
                .last_access_time
                .max(other_tensor_stats.last_access_time);
        }

        self.total_executions += other.total_executions;
    }

    /// Export profile to JSON
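    ///
    /// # Example
    ///
    /// A sketch of a JSON round trip through `to_json` / `from_json`:
    ///
    /// ```
    /// use tensorlogic_ir::ExecutionProfile;
    /// use std::time::Duration;
    ///
    /// let mut profile = ExecutionProfile::new();
    /// profile.record_node(0, Duration::from_millis(100), 1024);
    ///
    /// let json = profile.to_json().unwrap();
    /// let restored = ExecutionProfile::from_json(&json).unwrap();
    /// assert_eq!(restored.node_stats[&0].execution_count, 1);
    /// ```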
    pub fn to_json(&self) -> Result<String, IrError> {
        serde_json::to_string_pretty(self).map_err(|e| IrError::SerializationError(e.to_string()))
    }

    /// Import profile from JSON
    pub fn from_json(json: &str) -> Result<Self, IrError> {
        serde_json::from_str(json).map_err(|e| IrError::SerializationError(e.to_string()))
    }
}

/// Tensor usage statistics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TensorStats {
    /// Tensor size in bytes
    pub size_bytes: usize,
    /// Logical time of last access (the access counter doubles as the clock;
    /// used for liveness analysis)
    pub last_access_time: u64,
    /// Number of accesses
    pub access_count: u64,
}

impl TensorStats {
    pub fn new(size_bytes: usize) -> Self {
        TensorStats {
            size_bytes,
            access_count: 0,
            last_access_time: 0,
        }
    }

    pub fn record_access(&mut self) {
        self.access_count += 1;
        // The access counter serves as a logical clock for last_access_time.
        self.last_access_time = self.access_count;
    }

    /// Check if this tensor is frequently reused
    pub fn is_reused(&self, threshold: u64) -> bool {
        self.access_count >= threshold
    }
}

/// Optimization hint derived from profiling.
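///
/// # Example
///
/// A small sketch of filtering hints by kind; a backend would dispatch on
/// each variant it knows how to apply:
///
/// ```
/// use tensorlogic_ir::OptimizationHint;
///
/// let hints = vec![
///     OptimizationHint::CacheTensor(0),
///     OptimizationHint::InPlaceOp(3),
/// ];
/// let cache_hints = hints
///     .iter()
///     .filter(|h| matches!(h, OptimizationHint::CacheTensor(_)))
///     .count();
/// assert_eq!(cache_hints, 1);
/// ```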
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationHint {
    /// Fuse these nodes together
    FuseNodes(Vec<NodeId>),
    /// Parallelize these independent nodes
    Parallelize(Vec<NodeId>),
    /// Cache this tensor in fast memory
    CacheTensor(TensorId),
    /// Use in-place operation for this node
    InPlaceOp(NodeId),
    /// Prefetch this tensor
    Prefetch(TensorId),
    /// Tile this operation
    TileOperation { node: NodeId, tile_size: usize },
    /// Reorder operations for better cache locality
    ReorderOps(Vec<NodeId>),
    /// Allocate large buffer for this tensor
    PreAllocate { tensor: TensorId, size: usize },
}

/// Profile-guided optimizer.
#[derive(Clone, Debug)]
pub struct ProfileGuidedOptimizer {
    profile: ExecutionProfile,
    /// Hotness threshold (execution count)
    hot_threshold: u64,
    /// Memory threshold for considering a node memory-intensive (bytes)
    memory_threshold: u64,
}

impl ProfileGuidedOptimizer {
    pub fn new(profile: ExecutionProfile) -> Self {
        ProfileGuidedOptimizer {
            profile,
            hot_threshold: 10,
            memory_threshold: 100 * 1024 * 1024, // 100 MB
        }
    }

    /// Set hotness threshold
    pub fn with_hot_threshold(mut self, threshold: u64) -> Self {
        self.hot_threshold = threshold;
        self
    }

    /// Set memory threshold (bytes)
    pub fn with_memory_threshold(mut self, threshold: u64) -> Self {
        self.memory_threshold = threshold;
        self
    }

    /// Analyze graph and generate optimization hints
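    ///
    /// # Example
    ///
    /// A minimal sketch: a tensor accessed more often than the hot
    /// threshold yields a `CacheTensor` hint even on an empty graph.
    ///
    /// ```
    /// use tensorlogic_ir::{EinsumGraph, ExecutionProfile, OptimizationHint, ProfileGuidedOptimizer};
    ///
    /// let mut profile = ExecutionProfile::new();
    /// for _ in 0..50 {
    ///     profile.record_tensor_access(0, 4096);
    /// }
    ///
    /// let optimizer = ProfileGuidedOptimizer::new(profile).with_hot_threshold(10);
    /// let hints = optimizer.analyze(&EinsumGraph::new());
    /// assert!(hints
    ///     .iter()
    ///     .any(|h| matches!(h, OptimizationHint::CacheTensor(0))));
    /// ```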
    pub fn analyze(&self, graph: &EinsumGraph) -> Vec<OptimizationHint> {
        let mut hints = Vec::new();

        // 1. Identify hot nodes for fusion
        let hot_nodes = self.profile.get_hot_nodes(10);
        if hot_nodes.len() >= 2 {
            let node_ids: Vec<_> = hot_nodes.iter().map(|(id, _)| *id).collect();
            hints.push(OptimizationHint::FuseNodes(node_ids));
        }

        // 2. Identify memory-intensive operations
        let memory_nodes = self
            .profile
            .get_memory_intensive_nodes(self.memory_threshold);
        for node_id in memory_nodes {
            // Suggest in-place operations
            hints.push(OptimizationHint::InPlaceOp(node_id));

            // Check if we can tile the operation
            if self.is_tileable(node_id, graph) {
                hints.push(OptimizationHint::TileOperation {
                    node: node_id,
                    tile_size: 1024, // Default tile size
                });
            }
        }

        // 3. Identify frequently reused tensors for caching
        for (tensor_id, stats) in &self.profile.tensor_stats {
            if stats.is_reused(self.hot_threshold) {
                hints.push(OptimizationHint::CacheTensor(*tensor_id));
            }

            // Pre-allocate tensors larger than 1 MB
            if stats.size_bytes > 1024 * 1024 {
                hints.push(OptimizationHint::PreAllocate {
                    tensor: *tensor_id,
                    size: stats.size_bytes,
                });
            }
        }

        // 4. Find independent operations for parallelization
        let parallel_groups = self.find_parallel_groups(graph);
        for group in parallel_groups {
            if group.len() >= 2 {
                hints.push(OptimizationHint::Parallelize(group));
            }
        }

        hints
    }

    /// Check if a node can be tiled
    fn is_tileable(&self, _node_id: NodeId, _graph: &EinsumGraph) -> bool {
        // Simplified check - would analyze operation type and dimensions
        true
    }

    /// Find groups of independent operations that can be parallelized
    fn find_parallel_groups(&self, graph: &EinsumGraph) -> Vec<Vec<NodeId>> {
        let mut groups = Vec::new();

        // Simple algorithm: nodes at the same depth with no dependencies
        let depths = self.compute_depths(graph);
        let mut depth_map: HashMap<usize, Vec<NodeId>> = HashMap::new();

        for (node_id, depth) in depths {
            depth_map.entry(depth).or_default().push(node_id);
        }

        for (_, nodes) in depth_map {
            if nodes.len() >= 2 {
                groups.push(nodes);
            }
        }

        groups
    }

    /// Compute depth of each node in the graph
    fn compute_depths(&self, graph: &EinsumGraph) -> HashMap<NodeId, usize> {
        let mut depths = HashMap::new();
        // Share one memo table across all roots so each node's depth is
        // computed at most once, instead of rebuilding it per node.
        let mut memo = HashMap::new();

        for node_id in 0..graph.nodes.len() {
            depths.insert(node_id, self.compute_node_depth(node_id, graph, &mut memo));
        }

        depths
    }

    #[allow(clippy::only_used_in_recursion)]
    fn compute_node_depth(
        &self,
        node_id: NodeId,
        graph: &EinsumGraph,
        memo: &mut HashMap<NodeId, usize>,
    ) -> usize {
        if let Some(&depth) = memo.get(&node_id) {
            return depth;
        }

        let node = &graph.nodes[node_id];
        let input_depths: Vec<_> = node
            .inputs
            .iter()
            .filter_map(|&tensor_id| {
                // Find the node that produces this tensor
                graph.nodes.iter().enumerate().find_map(|(id, n)| {
                    if n.outputs.contains(&tensor_id) {
                        Some(self.compute_node_depth(id, graph, memo))
                    } else {
                        None
                    }
                })
            })
            .collect();

        let depth = if input_depths.is_empty() {
            0
        } else {
            input_depths.into_iter().max().unwrap() + 1
        };

        memo.insert(node_id, depth);
        depth
    }

    /// Apply optimization hints to a graph
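    ///
    /// Returns the number of hints that were applied; hints that need
    /// backend-specific support (e.g. `Prefetch`, `ReorderOps`) are skipped.
    ///
    /// # Example
    ///
    /// A sketch of applying a mixed batch of hints:
    ///
    /// ```
    /// use tensorlogic_ir::{EinsumGraph, ExecutionProfile, OptimizationHint, ProfileGuidedOptimizer};
    ///
    /// let optimizer = ProfileGuidedOptimizer::new(ExecutionProfile::new());
    /// let mut graph = EinsumGraph::new();
    /// let hints = vec![
    ///     OptimizationHint::CacheTensor(0), // applied (marks metadata)
    ///     OptimizationHint::Prefetch(0),    // backend-specific, skipped
    /// ];
    /// let applied = optimizer.apply_hints(&mut graph, &hints).unwrap();
    /// assert_eq!(applied, 1);
    /// ```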
    pub fn apply_hints(
        &self,
        graph: &mut EinsumGraph,
        hints: &[OptimizationHint],
    ) -> Result<usize, IrError> {
        let mut applied = 0;

        for hint in hints {
            match hint {
                OptimizationHint::FuseNodes(nodes) => {
                    if self.try_fuse_nodes(graph, nodes)? {
                        applied += 1;
                    }
                }
                OptimizationHint::CacheTensor(tensor_id) => {
                    self.mark_tensor_cached(graph, *tensor_id);
                    applied += 1;
                }
                OptimizationHint::InPlaceOp(node_id) => {
                    if self.try_make_inplace(graph, *node_id)? {
                        applied += 1;
                    }
                }
                OptimizationHint::PreAllocate { tensor, size } => {
                    self.mark_preallocate(graph, *tensor, *size);
                    applied += 1;
                }
                _ => {
                    // Other hints require backend-specific implementation
                }
            }
        }

        Ok(applied)
    }

    fn try_fuse_nodes(&self, _graph: &mut EinsumGraph, _nodes: &[NodeId]) -> Result<bool, IrError> {
        // Would implement actual fusion logic
        Ok(false)
    }

    fn mark_tensor_cached(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId) {
        // Mark tensor for caching (would add metadata)
    }

    fn try_make_inplace(
        &self,
        _graph: &mut EinsumGraph,
        _node_id: NodeId,
    ) -> Result<bool, IrError> {
        // Would check if operation can be in-place and modify
        Ok(false)
    }

    fn mark_preallocate(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId, _size: usize) {
        // Mark tensor for pre-allocation (would add metadata)
    }

    /// Get profile reference
    pub fn profile(&self) -> &ExecutionProfile {
        &self.profile
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_stats_basic() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        assert_eq!(stats.execution_count, 1);
        assert_eq!(stats.total_time, Duration::from_millis(100));
        assert_eq!(stats.peak_memory, 1024);

        stats.record_execution(Duration::from_millis(150), 2048);
        assert_eq!(stats.execution_count, 2);
        assert_eq!(stats.avg_time(), Duration::from_millis(125));
        assert_eq!(stats.peak_memory, 2048);
    }

    #[test]
    fn test_node_stats_min_max() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        stats.record_execution(Duration::from_millis(50), 512);
        stats.record_execution(Duration::from_millis(200), 4096);

        assert_eq!(stats.min_time, Duration::from_millis(50));
        assert_eq!(stats.max_time, Duration::from_millis(200));
        assert_eq!(stats.time_variance(), Duration::from_millis(150));
    }

    #[test]
    fn test_node_stats_hotness() {
        let mut stats = NodeStats::new();

        for _ in 0..5 {
            stats.record_execution(Duration::from_millis(10), 100);
        }

        assert!(!stats.is_hot(10));
        assert!(stats.is_hot(5));
        assert!(stats.is_hot(1));
    }

    #[test]
    fn test_execution_profile_record() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_node(1, Duration::from_millis(200), 2048);
        profile.record_node(0, Duration::from_millis(110), 1024);

        assert_eq!(profile.node_stats.len(), 2);
        assert_eq!(profile.node_stats[&0].execution_count, 2);
        assert_eq!(profile.node_stats[&1].execution_count, 1);
    }

    #[test]
    fn test_hot_nodes() {
        let mut profile = ExecutionProfile::new();

        // Node 0: frequently executed, fast
        for _ in 0..100 {
            profile.record_node(0, Duration::from_millis(10), 100);
        }

        // Node 1: rarely executed, slow
        for _ in 0..5 {
            profile.record_node(1, Duration::from_millis(500), 10000);
        }

        let hot_nodes = profile.get_hot_nodes(2);
        assert_eq!(hot_nodes.len(), 2);

        // With the default weights, node 0's execution count (100 * 0.2 = 20)
        // outweighs node 1's longer total runtime (2.5 s * 0.5 = 1.25), so
        // node 0 ranks first; either way the top score must be positive.
        assert!(hot_nodes[0].1 > 0.0);
    }

    #[test]
    fn test_tensor_stats() {
        let mut stats = TensorStats::new(1024);

        assert_eq!(stats.access_count, 0);

        stats.record_access();
        assert_eq!(stats.access_count, 1);
        assert_eq!(stats.last_access_time, 1);

        stats.record_access();
        assert_eq!(stats.access_count, 2);
        assert_eq!(stats.last_access_time, 2);

        assert!(stats.is_reused(2));
        assert!(!stats.is_reused(3));
    }

    #[test]
    fn test_profile_merge() {
        let mut profile1 = ExecutionProfile::new();
        profile1.record_node(0, Duration::from_millis(100), 1024);
        profile1.total_executions = 1;

        let mut profile2 = ExecutionProfile::new();
        profile2.record_node(0, Duration::from_millis(150), 2048);
        profile2.record_node(1, Duration::from_millis(200), 512);
        profile2.total_executions = 1;

        profile1.merge(&profile2);

        assert_eq!(profile1.node_stats.len(), 2);
        assert_eq!(profile1.node_stats[&0].execution_count, 2);
        assert_eq!(profile1.total_executions, 2);
    }

    #[test]
    fn test_profile_serialization() {
        let mut profile = ExecutionProfile::new();
        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_tensor_access(0, 2048);

        let json = profile.to_json().unwrap();
        let restored = ExecutionProfile::from_json(&json).unwrap();

        assert_eq!(profile.node_stats.len(), restored.node_stats.len());
        assert_eq!(profile.tensor_stats.len(), restored.tensor_stats.len());
    }

    #[test]
    fn test_pgo_optimizer_basic() {
        let mut profile = ExecutionProfile::new();

        // Create hot nodes
        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(50), 1024);
            profile.record_node(1, Duration::from_millis(60), 2048);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile);
        assert_eq!(optimizer.hot_threshold, 10);
    }

    #[test]
    fn test_optimization_hints() {
        let mut profile = ExecutionProfile::new();

        // Hot nodes for fusion
        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(10), 1024);
            profile.record_node(1, Duration::from_millis(10), 1024);
        }

        // Large memory node
        profile.record_node(2, Duration::from_millis(100), 200 * 1024 * 1024);

        // Frequently accessed tensor
        for _ in 0..50 {
            profile.record_tensor_access(0, 4096);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile)
            .with_hot_threshold(10)
            .with_memory_threshold(100 * 1024 * 1024);

        let graph = EinsumGraph::new();
        let hints = optimizer.analyze(&graph);

        // Should generate various hints
        assert!(!hints.is_empty());

        // Check for cache hint
        assert!(hints
            .iter()
            .any(|h| matches!(h, OptimizationHint::CacheTensor(_))));
    }

    #[test]
    fn test_memory_intensive_nodes() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(10), 50 * 1024 * 1024);
        profile.record_node(1, Duration::from_millis(10), 150 * 1024 * 1024);
        profile.record_node(2, Duration::from_millis(10), 1024);

        let memory_nodes = profile.get_memory_intensive_nodes(100 * 1024 * 1024);

        assert_eq!(memory_nodes.len(), 1);
        assert_eq!(memory_nodes[0], 1);
    }
}