1use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31use std::time::Duration;
32
33use crate::graph::EinsumGraph;
34use crate::IrError;
35
/// Index of a node within an `EinsumGraph`.
pub type NodeId = usize;
/// Index of a tensor within an `EinsumGraph`.
pub type TensorId = usize;
40
/// Timing and memory statistics accumulated for a single graph node.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NodeStats {
    /// Number of recorded executions of the node.
    pub execution_count: u64,
    /// Sum of all recorded execution durations.
    pub total_time: Duration,
    /// Fastest observed execution (meaningful only when `execution_count > 0`).
    pub min_time: Duration,
    /// Slowest observed execution.
    pub max_time: Duration,
    /// Cumulative bytes allocated across all executions.
    pub memory_allocated: u64,
    /// Largest single-execution memory use, in bytes.
    pub peak_memory: u64,
    /// Cache-miss count, when the collector supplies one.
    pub cache_misses: Option<u64>,
    /// Floating-point-operation count, when the collector supplies one.
    pub flops: Option<u64>,
}
61
62impl NodeStats {
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 pub fn record_execution(&mut self, duration: Duration, memory: u64) {
69 self.execution_count += 1;
70 self.total_time += duration;
71
72 if self.execution_count == 1 {
73 self.min_time = duration;
74 self.max_time = duration;
75 } else {
76 if duration < self.min_time {
77 self.min_time = duration;
78 }
79 if duration > self.max_time {
80 self.max_time = duration;
81 }
82 }
83
84 self.memory_allocated += memory;
85 if memory > self.peak_memory {
86 self.peak_memory = memory;
87 }
88 }
89
90 pub fn avg_time(&self) -> Duration {
92 if self.execution_count > 0 {
93 self.total_time / self.execution_count as u32
94 } else {
95 Duration::ZERO
96 }
97 }
98
99 pub fn time_variance(&self) -> Duration {
101 self.max_time.saturating_sub(self.min_time)
102 }
103
104 pub fn is_hot(&self, threshold: u64) -> bool {
106 self.execution_count >= threshold
107 }
108
109 pub fn performance_score(&self) -> f64 {
111 let time_weight = self.total_time.as_secs_f64();
112 let memory_weight = self.peak_memory as f64 / 1_000_000.0; let execution_weight = self.execution_count as f64;
114
115 time_weight * 0.5 + memory_weight * 0.3 + execution_weight * 0.2
116 }
117}
118
/// Aggregated runtime profile for a whole graph: per-node and per-tensor
/// statistics plus run-level bookkeeping. Serializable for persistence.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ExecutionProfile {
    /// Per-node timing/memory statistics, keyed by node id.
    pub node_stats: HashMap<NodeId, NodeStats>,
    /// Per-tensor access statistics, keyed by tensor id.
    pub tensor_stats: HashMap<TensorId, TensorStats>,
    /// Total number of graph executions folded into this profile.
    pub total_executions: u64,
    /// Nodes on the critical path. NOTE(review): never populated anywhere
    /// in this file — presumably filled in by an external component; confirm.
    pub critical_path: Vec<NodeId>,
}
131
132impl ExecutionProfile {
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn record_node(&mut self, node_id: NodeId, duration: Duration, memory: u64) {
139 self.node_stats
140 .entry(node_id)
141 .or_default()
142 .record_execution(duration, memory);
143 }
144
145 pub fn record_tensor_access(&mut self, tensor_id: TensorId, size: usize) {
147 self.tensor_stats
148 .entry(tensor_id)
149 .or_insert_with(|| TensorStats::new(size))
150 .record_access();
151 }
152
153 pub fn get_hot_nodes(&self, n: usize) -> Vec<(NodeId, f64)> {
155 let mut scores: Vec<_> = self
156 .node_stats
157 .iter()
158 .map(|(id, stats)| (*id, stats.performance_score()))
159 .collect();
160
161 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
162 scores.truncate(n);
163 scores
164 }
165
166 pub fn get_memory_intensive_nodes(&self, threshold: u64) -> Vec<NodeId> {
168 self.node_stats
169 .iter()
170 .filter(|(_, stats)| stats.peak_memory >= threshold)
171 .map(|(id, _)| *id)
172 .collect()
173 }
174
175 pub fn merge(&mut self, other: &ExecutionProfile) {
177 for (node_id, other_stats) in &other.node_stats {
178 let stats = self.node_stats.entry(*node_id).or_default();
179
180 stats.execution_count += other_stats.execution_count;
181 stats.total_time += other_stats.total_time;
182 stats.memory_allocated += other_stats.memory_allocated;
183
184 if other_stats.min_time < stats.min_time
185 || stats.execution_count == other_stats.execution_count
186 {
187 stats.min_time = other_stats.min_time;
188 }
189 if other_stats.max_time > stats.max_time {
190 stats.max_time = other_stats.max_time;
191 }
192 if other_stats.peak_memory > stats.peak_memory {
193 stats.peak_memory = other_stats.peak_memory;
194 }
195 }
196
197 for (tensor_id, other_tensor_stats) in &other.tensor_stats {
198 let tensor_stats = self
199 .tensor_stats
200 .entry(*tensor_id)
201 .or_insert_with(|| TensorStats::new(other_tensor_stats.size_bytes));
202
203 tensor_stats.access_count += other_tensor_stats.access_count;
204 tensor_stats.last_access_time = tensor_stats
205 .last_access_time
206 .max(other_tensor_stats.last_access_time);
207 }
208
209 self.total_executions += other.total_executions;
210 }
211
212 pub fn to_json(&self) -> Result<String, IrError> {
214 serde_json::to_string_pretty(self).map_err(|e| IrError::SerializationError(e.to_string()))
215 }
216
217 pub fn from_json(json: &str) -> Result<Self, IrError> {
219 serde_json::from_str(json).map_err(|e| IrError::SerializationError(e.to_string()))
220 }
221}
222
/// Access statistics for a single tensor.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TensorStats {
    /// Size of the tensor in bytes.
    pub size_bytes: usize,
    /// Number of recorded accesses.
    pub access_count: u64,
    /// Logical time of the last access. NOTE: this is a logical counter
    /// (set to `access_count` on each access), not a wall-clock timestamp.
    pub last_access_time: u64,
}
233
234impl TensorStats {
235 pub fn new(size_bytes: usize) -> Self {
236 TensorStats {
237 size_bytes,
238 access_count: 0,
239 last_access_time: 0,
240 }
241 }
242
243 pub fn record_access(&mut self) {
244 self.access_count += 1;
245 self.last_access_time = self.access_count;
246 }
247
248 pub fn is_reused(&self, threshold: u64) -> bool {
250 self.access_count >= threshold
251 }
252}
253
/// A suggested graph transformation derived from profiling data.
/// Hints are advisory; `ProfileGuidedOptimizer::apply_hints` decides which
/// are actually applied.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationHint {
    /// Fuse the listed nodes into a single operation.
    FuseNodes(Vec<NodeId>),
    /// Execute the listed nodes concurrently.
    Parallelize(Vec<NodeId>),
    /// Keep the tensor resident between uses.
    CacheTensor(TensorId),
    /// Rewrite the node to operate in place, avoiding an output allocation.
    InPlaceOp(NodeId),
    /// Fetch the tensor ahead of its first use.
    Prefetch(TensorId),
    /// Tile the node's operation with the given tile size.
    TileOperation { node: NodeId, tile_size: usize },
    /// Reorder the listed operations.
    ReorderOps(Vec<NodeId>),
    /// Reserve `size` bytes for the tensor up front.
    PreAllocate { tensor: TensorId, size: usize },
}
274
/// Derives [`OptimizationHint`]s for a graph from an [`ExecutionProfile`].
#[derive(Clone, Debug)]
pub struct ProfileGuidedOptimizer {
    // The execution profile driving the analysis.
    profile: ExecutionProfile,
    // Minimum execution/access count for "hot" classification (default 10).
    hot_threshold: u64,
    // Peak-memory cutoff in bytes for memory-intensive nodes
    // (default 100 MiB).
    memory_threshold: u64,
}
284
285impl ProfileGuidedOptimizer {
286 pub fn new(profile: ExecutionProfile) -> Self {
287 ProfileGuidedOptimizer {
288 profile,
289 hot_threshold: 10,
290 memory_threshold: 100 * 1024 * 1024, }
292 }
293
294 pub fn with_hot_threshold(mut self, threshold: u64) -> Self {
296 self.hot_threshold = threshold;
297 self
298 }
299
300 pub fn with_memory_threshold(mut self, threshold: u64) -> Self {
302 self.memory_threshold = threshold;
303 self
304 }
305
306 pub fn analyze(&self, graph: &EinsumGraph) -> Vec<OptimizationHint> {
308 let mut hints = Vec::new();
309
310 let hot_nodes = self.profile.get_hot_nodes(10);
312 if hot_nodes.len() >= 2 {
313 let node_ids: Vec<_> = hot_nodes.iter().map(|(id, _)| *id).collect();
314 hints.push(OptimizationHint::FuseNodes(node_ids));
315 }
316
317 let memory_nodes = self
319 .profile
320 .get_memory_intensive_nodes(self.memory_threshold);
321 for node_id in memory_nodes {
322 hints.push(OptimizationHint::InPlaceOp(node_id));
324
325 if self.is_tileable(node_id, graph) {
327 hints.push(OptimizationHint::TileOperation {
328 node: node_id,
329 tile_size: 1024, });
331 }
332 }
333
334 for (tensor_id, stats) in &self.profile.tensor_stats {
336 if stats.is_reused(self.hot_threshold) {
337 hints.push(OptimizationHint::CacheTensor(*tensor_id));
338 }
339
340 if stats.size_bytes > 1024 * 1024 {
342 hints.push(OptimizationHint::PreAllocate {
344 tensor: *tensor_id,
345 size: stats.size_bytes,
346 });
347 }
348 }
349
350 let parallel_groups = self.find_parallel_groups(graph);
352 for group in parallel_groups {
353 if group.len() >= 2 {
354 hints.push(OptimizationHint::Parallelize(group));
355 }
356 }
357
358 hints
359 }
360
361 fn is_tileable(&self, _node_id: NodeId, _graph: &EinsumGraph) -> bool {
363 true
365 }
366
367 fn find_parallel_groups(&self, graph: &EinsumGraph) -> Vec<Vec<NodeId>> {
369 let mut groups = Vec::new();
370
371 let depths = self.compute_depths(graph);
373 let mut depth_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
374
375 for (node_id, depth) in depths {
376 depth_map.entry(depth).or_default().push(node_id);
377 }
378
379 for (_, nodes) in depth_map {
380 if nodes.len() >= 2 {
381 groups.push(nodes);
382 }
383 }
384
385 groups
386 }
387
388 fn compute_depths(&self, graph: &EinsumGraph) -> HashMap<NodeId, usize> {
390 let mut depths = HashMap::new();
391
392 for node_id in 0..graph.nodes.len() {
393 depths.insert(
394 node_id,
395 self.compute_node_depth(node_id, graph, &mut HashMap::new()),
396 );
397 }
398
399 depths
400 }
401
402 #[allow(clippy::only_used_in_recursion)]
403 fn compute_node_depth(
404 &self,
405 node_id: NodeId,
406 graph: &EinsumGraph,
407 memo: &mut HashMap<NodeId, usize>,
408 ) -> usize {
409 if let Some(&depth) = memo.get(&node_id) {
410 return depth;
411 }
412
413 let node = &graph.nodes[node_id];
414 let input_depths: Vec<_> = node
415 .inputs
416 .iter()
417 .filter_map(|&tensor_id| {
418 graph.nodes.iter().enumerate().find_map(|(id, n)| {
420 if n.outputs.contains(&tensor_id) {
421 Some(self.compute_node_depth(id, graph, memo))
422 } else {
423 None
424 }
425 })
426 })
427 .collect();
428
429 let depth = if input_depths.is_empty() {
430 0
431 } else {
432 input_depths.into_iter().max().unwrap_or(0) + 1
433 };
434
435 memo.insert(node_id, depth);
436 depth
437 }
438
439 pub fn apply_hints(
441 &self,
442 graph: &mut EinsumGraph,
443 hints: &[OptimizationHint],
444 ) -> Result<usize, IrError> {
445 let mut applied = 0;
446
447 for hint in hints {
448 match hint {
449 OptimizationHint::FuseNodes(nodes) if self.try_fuse_nodes(graph, nodes)? => {
450 applied += 1;
451 }
452 OptimizationHint::CacheTensor(tensor_id) => {
453 self.mark_tensor_cached(graph, *tensor_id);
454 applied += 1;
455 }
456 OptimizationHint::InPlaceOp(node_id)
457 if self.try_make_inplace(graph, *node_id)? =>
458 {
459 applied += 1;
460 }
461 OptimizationHint::PreAllocate { tensor, size } => {
462 self.mark_preallocate(graph, *tensor, *size);
463 applied += 1;
464 }
465 _ => {
466 }
468 }
469 }
470
471 Ok(applied)
472 }
473
474 fn try_fuse_nodes(&self, _graph: &mut EinsumGraph, _nodes: &[NodeId]) -> Result<bool, IrError> {
475 Ok(false)
477 }
478
479 fn mark_tensor_cached(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId) {
480 }
482
483 fn try_make_inplace(
484 &self,
485 _graph: &mut EinsumGraph,
486 _node_id: NodeId,
487 ) -> Result<bool, IrError> {
488 Ok(false)
490 }
491
492 fn mark_preallocate(&self, _graph: &mut EinsumGraph, _tensor_id: TensorId, _size: usize) {
493 }
495
496 pub fn profile(&self) -> &ExecutionProfile {
498 &self.profile
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    // Counters, totals and peak memory after one and two executions.
    #[test]
    fn test_node_stats_basic() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        assert_eq!(stats.execution_count, 1);
        assert_eq!(stats.total_time, Duration::from_millis(100));
        assert_eq!(stats.peak_memory, 1024);

        stats.record_execution(Duration::from_millis(150), 2048);
        assert_eq!(stats.execution_count, 2);
        assert_eq!(stats.avg_time(), Duration::from_millis(125));
        assert_eq!(stats.peak_memory, 2048);
    }

    // min/max tracking and the variance (max - min) derived from them.
    #[test]
    fn test_node_stats_min_max() {
        let mut stats = NodeStats::new();

        stats.record_execution(Duration::from_millis(100), 1024);
        stats.record_execution(Duration::from_millis(50), 512);
        stats.record_execution(Duration::from_millis(200), 4096);

        assert_eq!(stats.min_time, Duration::from_millis(50));
        assert_eq!(stats.max_time, Duration::from_millis(200));
        assert_eq!(stats.time_variance(), Duration::from_millis(150));
    }

    // is_hot compares execution_count against the given threshold (>=).
    #[test]
    fn test_node_stats_hotness() {
        let mut stats = NodeStats::new();

        for _ in 0..5 {
            stats.record_execution(Duration::from_millis(10), 100);
        }

        assert!(!stats.is_hot(10));
        assert!(stats.is_hot(5));
        assert!(stats.is_hot(1));
    }

    // record_node creates one entry per node id and accumulates counts.
    #[test]
    fn test_execution_profile_record() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_node(1, Duration::from_millis(200), 2048);
        profile.record_node(0, Duration::from_millis(110), 1024);

        assert_eq!(profile.node_stats.len(), 2);
        assert_eq!(profile.node_stats[&0].execution_count, 2);
        assert_eq!(profile.node_stats[&1].execution_count, 1);
    }

    // get_hot_nodes returns up to n (id, score) pairs, best first.
    #[test]
    fn test_hot_nodes() {
        let mut profile = ExecutionProfile::new();

        // Node 0: many cheap executions.
        for _ in 0..100 {
            profile.record_node(0, Duration::from_millis(10), 100);
        }

        // Node 1: few expensive executions.
        for _ in 0..5 {
            profile.record_node(1, Duration::from_millis(500), 10000);
        }

        let hot_nodes = profile.get_hot_nodes(2);
        assert_eq!(hot_nodes.len(), 2);

        assert!(hot_nodes[0].1 > 0.0);
    }

    // TensorStats counts accesses; last_access_time mirrors access_count.
    #[test]
    fn test_tensor_stats() {
        let mut stats = TensorStats::new(1024);

        assert_eq!(stats.access_count, 0);

        stats.record_access();
        assert_eq!(stats.access_count, 1);
        assert_eq!(stats.last_access_time, 1);

        stats.record_access();
        assert_eq!(stats.access_count, 2);
        assert_eq!(stats.last_access_time, 2);

        assert!(stats.is_reused(2));
        assert!(!stats.is_reused(3));
    }

    // merge sums per-node stats and total_executions across profiles.
    #[test]
    fn test_profile_merge() {
        let mut profile1 = ExecutionProfile::new();
        profile1.record_node(0, Duration::from_millis(100), 1024);
        profile1.total_executions = 1;

        let mut profile2 = ExecutionProfile::new();
        profile2.record_node(0, Duration::from_millis(150), 2048);
        profile2.record_node(1, Duration::from_millis(200), 512);
        profile2.total_executions = 1;

        profile1.merge(&profile2);

        assert_eq!(profile1.node_stats.len(), 2);
        assert_eq!(profile1.node_stats[&0].execution_count, 2);
        assert_eq!(profile1.total_executions, 2);
    }

    // JSON round-trip preserves the node and tensor stat maps.
    #[test]
    fn test_profile_serialization() {
        let mut profile = ExecutionProfile::new();
        profile.record_node(0, Duration::from_millis(100), 1024);
        profile.record_tensor_access(0, 2048);

        let json = profile.to_json().expect("unwrap");
        let restored = ExecutionProfile::from_json(&json).expect("unwrap");

        assert_eq!(profile.node_stats.len(), restored.node_stats.len());
        assert_eq!(profile.tensor_stats.len(), restored.tensor_stats.len());
    }

    // The optimizer's default hot threshold is 10.
    #[test]
    fn test_pgo_optimizer_basic() {
        let mut profile = ExecutionProfile::new();

        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(50), 1024);
            profile.record_node(1, Duration::from_millis(60), 2048);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile);
        assert_eq!(optimizer.hot_threshold, 10);
    }

    // analyze produces at least a CacheTensor hint for a heavily reused
    // tensor, alongside whatever node-level hints the profile triggers.
    #[test]
    fn test_optimization_hints() {
        let mut profile = ExecutionProfile::new();

        for _ in 0..20 {
            profile.record_node(0, Duration::from_millis(10), 1024);
            profile.record_node(1, Duration::from_millis(10), 1024);
        }

        // One node well above the 100 MiB memory threshold.
        profile.record_node(2, Duration::from_millis(100), 200 * 1024 * 1024);

        // Tensor 0 accessed far more than the hot threshold.
        for _ in 0..50 {
            profile.record_tensor_access(0, 4096);
        }

        let optimizer = ProfileGuidedOptimizer::new(profile)
            .with_hot_threshold(10)
            .with_memory_threshold(100 * 1024 * 1024);

        let graph = EinsumGraph::new();
        let hints = optimizer.analyze(&graph);

        assert!(!hints.is_empty());

        assert!(hints
            .iter()
            .any(|h| matches!(h, OptimizationHint::CacheTensor(_))));
    }

    // Only nodes at or above the peak-memory threshold are reported.
    #[test]
    fn test_memory_intensive_nodes() {
        let mut profile = ExecutionProfile::new();

        profile.record_node(0, Duration::from_millis(10), 50 * 1024 * 1024);
        profile.record_node(1, Duration::from_millis(10), 150 * 1024 * 1024);
        profile.record_node(2, Duration::from_millis(10), 1024);

        let memory_nodes = profile.get_memory_intensive_nodes(100 * 1024 * 1024);

        assert_eq!(memory_nodes.len(), 1);
        assert_eq!(memory_nodes[0], 1);
    }
}