use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;

use crate::graph::EinsumGraph;
use crate::IrError;

/// Identifier of a node in an [`EinsumGraph`].
pub type NodeId = usize;
/// Identifier of a tensor in an [`EinsumGraph`].
pub type TensorId = usize;

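/// Runtime statistics collected for a single graph node.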
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NodeStats {
    /// Number of times the node has been executed.
    pub execution_count: u64,
    /// Total time spent executing the node.
    pub total_time: Duration,
    /// Fastest observed execution.
    pub min_time: Duration,
    /// Slowest observed execution.
    pub max_time: Duration,
    /// Total bytes allocated across all executions.
    pub memory_allocated: u64,
    /// Largest memory footprint seen in a single execution.
    pub peak_memory: u64,
    /// Cache misses, if the platform reports them.
    pub cache_misses: Option<u64>,
    /// Floating-point operation count, if known.
    pub flops: Option<u64>,
}

impl NodeStats {
    pub fn new() -> Self {
        Self::default()
    }

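    /// Records one execution, updating timing and memory counters.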
    pub fn record_execution(&mut self, duration: Duration, memory: u64) {
        self.execution_count += 1;
        self.total_time += duration;

        if self.execution_count == 1 {
            self.min_time = duration;
            self.max_time = duration;
        } else {
            if duration < self.min_time {
                self.min_time = duration;
            }
            if duration > self.max_time {
                self.max_time = duration;
            }
        }

        self.memory_allocated += memory;
        if memory > self.peak_memory {
            self.peak_memory = memory;
        }
    }

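    /// Average execution time, or `Duration::ZERO` if the node never ran.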
    pub fn avg_time(&self) -> Duration {
        if self.execution_count > 0 {
            self.total_time / self.execution_count as u32
        } else {
            Duration::ZERO
        }
    }

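    /// Spread between the slowest and fastest observed execution.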
    pub fn time_variance(&self) -> Duration {
        self.max_time.saturating_sub(self.min_time)
    }

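    /// Returns `true` once the node has executed at least `threshold` times.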
    pub fn is_hot(&self, threshold: u64) -> bool {
        self.execution_count >= threshold
    }

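    /// Weighted score combining total time, peak memory (in MB), and execution count.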
    pub fn performance_score(&self) -> f64 {
        let time_weight = self.total_time.as_secs_f64();
        let memory_weight = self.peak_memory as f64 / 1_000_000.0;
        let execution_weight = self.execution_count as f64;

        time_weight * 0.5 + memory_weight * 0.3 + execution_weight * 0.2
    }
}

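/// Aggregated profiling data for an entire graph execution.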
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ExecutionProfile {
    /// Per-node execution statistics.
    pub node_stats: HashMap<NodeId, NodeStats>,
    /// Per-tensor access statistics.
    pub tensor_stats: HashMap<TensorId, TensorStats>,
    /// Number of complete graph executions recorded.
    pub total_executions: u64,
    /// Nodes forming the critical path, if it has been computed.
    pub critical_path: Vec<NodeId>,
}

impl ExecutionProfile {
    pub fn new() -> Self {
        Self::default()
    }

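    /// Records a single execution of `node_id` in the profile.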
    pub fn record_node(&mut self, node_id: NodeId, duration: Duration, memory: u64) {
        self.node_stats
            .entry(node_id)
            .or_default()
            .record_execution(duration, memory);
    }

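    /// Records an access to a tensor, creating its stats entry on first use.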
    pub fn record_tensor_access(&mut self, tensor_id: TensorId, size: usize) {
        self.tensor_stats
            .entry(tensor_id)
            .or_insert_with(|| TensorStats::new(size))
            .record_access();
    }

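    /// Returns up to `n` nodes with the highest performance scores, sorted descending.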
    pub fn get_hot_nodes(&self, n: usize) -> Vec<(NodeId, f64)> {
        let mut scores: Vec<_> = self
            .node_stats
            .iter()
            .map(|(id, stats)| (*id, stats.performance_score()))
            .collect();

        // Sort descending by score; `total_cmp` cannot panic, unlike
        // `partial_cmp(...).unwrap()`.
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores.truncate(n);
        scores
    }

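    /// Nodes whose peak memory usage is at least `threshold` bytes.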
    pub fn get_memory_intensive_nodes(&self, threshold: u64) -> Vec<NodeId> {
        self.node_stats
            .iter()
            .filter(|(_, stats)| stats.peak_memory >= threshold)
            .map(|(id, _)| *id)
            .collect()
    }

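    /// Merges another profile into this one, accumulating counts and extremes.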
    pub fn merge(&mut self, other: &ExecutionProfile) {
        for (node_id, other_stats) in &other.node_stats {
            let stats = self.node_stats.entry(*node_id).or_default();

            stats.execution_count += other_stats.execution_count;
            stats.total_time += other_stats.total_time;
            stats.memory_allocated += other_stats.memory_allocated;

            // Take `other`'s minimum when it is smaller, or when this entry was
            // freshly created (its prior count was zero, so the counts are now
            // equal) and the default minimum of zero would be misleading.
            if other_stats.min_time < stats.min_time
                || stats.execution_count == other_stats.execution_count
            {
                stats.min_time = other_stats.min_time;
            }
            if other_stats.max_time > stats.max_time {
                stats.max_time = other_stats.max_time;
            }
            if other_stats.peak_memory > stats.peak_memory {
                stats.peak_memory = other_stats.peak_memory;
            }
        }

        for (tensor_id, other_tensor_stats) in &other.tensor_stats {
            let tensor_stats = self
                .tensor_stats
                .entry(*tensor_id)
                .or_insert_with(|| TensorStats::new(other_tensor_stats.size_bytes));

            tensor_stats.access_count += other_tensor_stats.access_count;
            tensor_stats.last_access_time = tensor_stats
                .last_access_time
                .max(other_tensor_stats.last_access_time);
        }

        self.total_executions += other.total_executions;
    }

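    /// Serializes the profile to pretty-printed JSON.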
    pub fn to_json(&self) -> Result<String, IrError> {
        serde_json::to_string_pretty(self).map_err(|e| IrError::SerializationError(e.to_string()))
    }

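    /// Restores a profile from its JSON representation.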
    pub fn from_json(json: &str) -> Result<Self, IrError> {
        serde_json::from_str(json).map_err(|e| IrError::SerializationError(e.to_string()))
    }
}

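/// Access statistics for a single tensor.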
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TensorStats {
    /// Size of the tensor in bytes.
    pub size_bytes: usize,
    /// Number of recorded accesses.
    pub access_count: u64,
    /// Logical timestamp of the most recent access.
    pub last_access_time: u64,
}

impl TensorStats {
    pub fn new(size_bytes: usize) -> Self {
        TensorStats {
            size_bytes,
            access_count: 0,
            last_access_time: 0,
        }
    }

    /// Records an access; the running access count doubles as a logical timestamp.
    pub fn record_access(&mut self) {
        self.access_count += 1;
        self.last_access_time = self.access_count;
    }

    /// Returns `true` if the tensor has been accessed at least `threshold` times.
    pub fn is_reused(&self, threshold: u64) -> bool {
        self.access_count >= threshold
    }
}

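/// Optimization opportunities derived from profiling data.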
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationHint {
    /// Fuse the listed nodes into a single operation.
    FuseNodes(Vec<NodeId>),
    /// Execute the listed nodes in parallel.
    Parallelize(Vec<NodeId>),
    /// Keep the tensor cached between uses.
    CacheTensor(TensorId),
    /// Perform the node's operation in place.
    InPlaceOp(NodeId),
    /// Prefetch the tensor before it is needed.
    Prefetch(TensorId),
    /// Tile the node's operation with the given tile size.
    TileOperation { node: NodeId, tile_size: usize },
    /// Reorder the listed operations.
    ReorderOps(Vec<NodeId>),
    /// Pre-allocate storage for the tensor.
    PreAllocate { tensor: TensorId, size: usize },
}

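/// Derives optimization hints for a graph from a recorded execution profile.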
#[derive(Clone, Debug)]
pub struct ProfileGuidedOptimizer {
    profile: ExecutionProfile,
    /// Minimum execution count for a node or tensor to be considered hot.
    hot_threshold: u64,
    /// Peak-memory threshold (in bytes) for memory-oriented hints.
    memory_threshold: u64,
}

impl ProfileGuidedOptimizer {
    pub fn new(profile: ExecutionProfile) -> Self {
        ProfileGuidedOptimizer {
            profile,
            hot_threshold: 10,
            memory_threshold: 100 * 1024 * 1024, // 100 MiB
        }
    }

    /// Sets the execution-count threshold used to classify hot nodes and tensors.
    pub fn with_hot_threshold(mut self, threshold: u64) -> Self {
        self.hot_threshold = threshold;
        self
    }

    /// Sets the peak-memory threshold (in bytes) for memory-related hints.
    pub fn with_memory_threshold(mut self, threshold: u64) -> Self {
        self.memory_threshold = threshold;
        self
    }

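    /// Analyzes the profile against `graph` and returns suggested optimizations.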
    pub fn analyze(&self, graph: &EinsumGraph) -> Vec<OptimizationHint> {
        let mut hints = Vec::new();

        // Suggest fusing the most expensive nodes.
        let hot_nodes = self.profile.get_hot_nodes(10);
        if hot_nodes.len() >= 2 {
            let node_ids: Vec<_> = hot_nodes.iter().map(|(id, _)| *id).collect();
            hints.push(OptimizationHint::FuseNodes(node_ids));
        }

        // Memory-heavy nodes: try in-place execution and, where possible, tiling.
        let memory_nodes = self
            .profile
            .get_memory_intensive_nodes(self.memory_threshold);
        for node_id in memory_nodes {
            hints.push(OptimizationHint::InPlaceOp(node_id));

            if self.is_tileable(node_id, graph) {
                hints.push(OptimizationHint::TileOperation {
                    node: node_id,
                    tile_size: 1024,
                });
            }
        }

        // Frequently reused tensors are worth caching; large ones are worth
        // pre-allocating.
        for (tensor_id, stats) in &self.profile.tensor_stats {
            if stats.is_reused(self.hot_threshold) {
                hints.push(OptimizationHint::CacheTensor(*tensor_id));
            }

            if stats.size_bytes > 1024 * 1024 {
                hints.push(OptimizationHint::PreAllocate {
                    tensor: *tensor_id,
                    size: stats.size_bytes,
                });
            }
        }

        // Independent nodes at the same dependency depth can run in parallel.
        let parallel_groups = self.find_parallel_groups(graph);
        for group in parallel_groups {
            if group.len() >= 2 {
                hints.push(OptimizationHint::Parallelize(group));
            }
        }

        hints
    }

    /// Whether the node's operation can be tiled.
    fn is_tileable(&self, _node_id: NodeId, _graph: &EinsumGraph) -> bool {
        // Placeholder: treat every operation as tileable until shape analysis
        // is available.
        true
    }

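    /// Groups nodes by dependency depth; nodes at the same depth cannot depend on each other.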
    fn find_parallel_groups(&self, graph: &EinsumGraph) -> Vec<Vec<NodeId>> {
        let mut groups = Vec::new();

        // Bucket nodes by dependency depth.
        let depths = self.compute_depths(graph);
        let mut depth_map: HashMap<usize, Vec<NodeId>> = HashMap::new();

        for (node_id, depth) in depths {
            depth_map.entry(depth).or_default().push(node_id);
        }

        for nodes in depth_map.into_values() {
            if nodes.len() >= 2 {
                groups.push(nodes);
            }
        }

        groups
    }

    /// Computes the dependency depth of every node in the graph.
    fn compute_depths(&self, graph: &EinsumGraph) -> HashMap<NodeId, usize> {
        let mut depths = HashMap::new();
        // Share one memo table across all starting nodes so each node's depth
        // is computed at most once.
        let mut memo = HashMap::new();

        for node_id in 0..graph.nodes.len() {
            depths.insert(node_id, self.compute_node_depth(node_id, graph, &mut memo));
        }

        depths
    }

    #[allow(clippy::only_used_in_recursion)]
    fn compute_node_depth(
        &self,
        node_id: NodeId,
        graph: &EinsumGraph,
        memo: &mut HashMap<NodeId, usize>,
    ) -> usize {
        if let Some(&depth) = memo.get(&node_id) {
            return depth;
        }

        // A node's depth is one more than the deepest producer of its inputs.
        let node = &graph.nodes[node_id];
        let input_depths: Vec<_> = node
            .inputs
            .iter()
            .filter_map(|&tensor_id| {
                graph.nodes.iter().enumerate().find_map(|(id, n)| {
                    if n.outputs.contains(&tensor_id) {
                        Some(self.compute_node_depth(id, graph, memo))
                    } else {
                        None
                    }
                })
            })
            .collect();

        let depth = if input_depths.is_empty() {
            0
        } else {
            input_depths.into_iter().max().unwrap() + 1
        };

        memo.insert(node_id, depth);
        depth
    }

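    /// Applies the given hints to `graph`, returning how many were actually applied.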
    pub fn apply_hints(
        &self,
        graph: &mut EinsumGraph,
        hints: &[OptimizationHint],
    ) -> Result<usize, IrError> {
        let mut applied = 0;

        for hint in hints {
            match hint {
                OptimizationHint::FuseNodes(nodes) => {
                    if self.try_fuse_nodes(graph, nodes)? {
                        applied += 1;
                    }
                }
                OptimizationHint::CacheTensor(tensor_id) => {
                    self.mark_tensor_cached(graph, *tensor_id);
                    applied += 1;
                }
                OptimizationHint::InPlaceOp(node_id) => {
                    if self.try_make_inplace(graph, *node_id)? {
                        applied += 1;
                    }
                }
                OptimizationHint::PreAllocate { tensor, size } => {
                    self.mark_preallocate(graph, *tensor, *size);
                    applied += 1;
                }
                _ => {
                    // Remaining hint kinds are advisory only and are not applied here.
                }
            }
        }

        Ok(applied)
    }

    fn try_fuse_nodes(&self, _graph: &mut EinsumGraph, _nodes: &[NodeId]) -> Result<bool, IrError> {
        // Fusion is not implemented yet; report that nothing changed.
        Ok(false)
    }

    fn mark_tensor_cached(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId) {
        // Caching metadata is not recorded yet.
    }

    fn try_make_inplace(
        &self,
        _graph: &mut EinsumGraph,
        _node_id: NodeId,
    ) -> Result<bool, IrError> {
        // In-place rewriting is not implemented yet; report that nothing changed.
        Ok(false)
    }

    fn mark_preallocate(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId, _size: usize) {
        // Pre-allocation metadata is not recorded yet.
    }

    /// Read-only access to the underlying execution profile.
    pub fn profile(&self) -> &ExecutionProfile {
        &self.profile
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_stats_basic() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        assert_eq!(stats.execution_count, 1);
        assert_eq!(stats.total_time, Duration::from_millis(100));
        assert_eq!(stats.peak_memory, 1024);

        stats.record_execution(Duration::from_millis(150), 2048);
        assert_eq!(stats.execution_count, 2);
        assert_eq!(stats.avg_time(), Duration::from_millis(125));
        assert_eq!(stats.peak_memory, 2048);
    }

    #[test]
    fn test_node_stats_min_max() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        stats.record_execution(Duration::from_millis(50), 512);
        stats.record_execution(Duration::from_millis(200), 4096);

        assert_eq!(stats.min_time, Duration::from_millis(50));
        assert_eq!(stats.max_time, Duration::from_millis(200));
        assert_eq!(stats.time_variance(), Duration::from_millis(150));
    }

    #[test]
    fn test_node_stats_hotness() {
        let mut stats = NodeStats::new();

        for _ in 0..5 {
            stats.record_execution(Duration::from_millis(10), 100);
        }

        assert!(!stats.is_hot(10));
        assert!(stats.is_hot(5));
        assert!(stats.is_hot(1));
    }

    #[test]
    fn test_execution_profile_record() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_node(1, Duration::from_millis(200), 2048);
        profile.record_node(0, Duration::from_millis(110), 1024);

        assert_eq!(profile.node_stats.len(), 2);
        assert_eq!(profile.node_stats[&0].execution_count, 2);
        assert_eq!(profile.node_stats[&1].execution_count, 1);
    }

    #[test]
    fn test_hot_nodes() {
        let mut profile = ExecutionProfile::new();

        // Node 0: many cheap executions.
        for _ in 0..100 {
            profile.record_node(0, Duration::from_millis(10), 100);
        }

        // Node 1: few expensive executions.
        for _ in 0..5 {
            profile.record_node(1, Duration::from_millis(500), 10000);
        }

        let hot_nodes = profile.get_hot_nodes(2);
        assert_eq!(hot_nodes.len(), 2);

        // The top-ranked node should have a positive score.
        assert!(hot_nodes[0].1 > 0.0);
    }

    #[test]
    fn test_tensor_stats() {
        let mut stats = TensorStats::new(1024);

        assert_eq!(stats.access_count, 0);

        stats.record_access();
        assert_eq!(stats.access_count, 1);
        assert_eq!(stats.last_access_time, 1);

        stats.record_access();
        assert_eq!(stats.access_count, 2);
        assert_eq!(stats.last_access_time, 2);

        assert!(stats.is_reused(2));
        assert!(!stats.is_reused(3));
    }

    #[test]
    fn test_profile_merge() {
        let mut profile1 = ExecutionProfile::new();
        profile1.record_node(0, Duration::from_millis(100), 1024);
        profile1.total_executions = 1;

        let mut profile2 = ExecutionProfile::new();
        profile2.record_node(0, Duration::from_millis(150), 2048);
        profile2.record_node(1, Duration::from_millis(200), 512);
        profile2.total_executions = 1;

        profile1.merge(&profile2);

        assert_eq!(profile1.node_stats.len(), 2);
        assert_eq!(profile1.node_stats[&0].execution_count, 2);
        assert_eq!(profile1.total_executions, 2);
    }

    #[test]
    fn test_profile_serialization() {
        let mut profile = ExecutionProfile::new();
        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_tensor_access(0, 2048);

        let json = profile.to_json().unwrap();
        let restored = ExecutionProfile::from_json(&json).unwrap();

        assert_eq!(profile.node_stats.len(), restored.node_stats.len());
        assert_eq!(profile.tensor_stats.len(), restored.tensor_stats.len());
    }

    #[test]
    fn test_pgo_optimizer_basic() {
        let mut profile = ExecutionProfile::new();

        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(50), 1024);
            profile.record_node(1, Duration::from_millis(60), 2048);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile);
        assert_eq!(optimizer.hot_threshold, 10);
    }

    #[test]
    fn test_optimization_hints() {
        let mut profile = ExecutionProfile::new();

        // Two frequently executed nodes.
        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(10), 1024);
            profile.record_node(1, Duration::from_millis(10), 1024);
        }

        // One memory-intensive node.
        profile.record_node(2, Duration::from_millis(100), 200 * 1024 * 1024);

        // One frequently reused tensor.
        for _ in 0..50 {
            profile.record_tensor_access(0, 4096);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile)
            .with_hot_threshold(10)
            .with_memory_threshold(100 * 1024 * 1024);

        let graph = EinsumGraph::new();
        let hints = optimizer.analyze(&graph);

        assert!(!hints.is_empty());

        assert!(hints
            .iter()
            .any(|h| matches!(h, OptimizationHint::CacheTensor(_))));
    }

    #[test]
    fn test_memory_intensive_nodes() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(10), 50 * 1024 * 1024);
        profile.record_node(1, Duration::from_millis(10), 150 * 1024 * 1024);
        profile.record_node(2, Duration::from_millis(10), 1024);

        let memory_nodes = profile.get_memory_intensive_nodes(100 * 1024 * 1024);

        assert_eq!(memory_nodes.len(), 1);
        assert_eq!(memory_nodes[0], 1);
    }
}