1use std::cmp::Reverse;
10use std::collections::{HashMap, HashSet};
11use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct FusionOpportunity {
16 pub producer_idx: usize,
17 pub consumer_idx: usize,
18 pub fusion_type: FusionType,
19 pub estimated_speedup: u32, }
21
/// The categories of producer/consumer fusion this module distinguishes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionType {
    /// Both nodes are element-wise (unary or binary) operations.
    ElementWise,
    /// A reduction whose result feeds an element-wise operation.
    ReductionElementWise,
    /// Fusion of several reductions. Nothing in this file produces this
    /// variant; it only has an entry in the speedup table.
    MultiReduction,
    /// An einsum whose result feeds another einsum.
    EinsumChain,
}
34
35#[derive(Debug, Clone)]
37pub struct OptimizationResult {
38 pub fusion_opportunities: Vec<FusionOpportunity>,
39 pub dead_nodes: Vec<usize>,
40 pub redundant_computations: Vec<(usize, usize)>, pub estimated_improvement: f64, }
43
44impl OptimizationResult {
45 pub fn new() -> Self {
46 OptimizationResult {
47 fusion_opportunities: Vec::new(),
48 dead_nodes: Vec::new(),
49 redundant_computations: Vec::new(),
50 estimated_improvement: 0.0,
51 }
52 }
53
54 pub fn is_empty(&self) -> bool {
55 self.fusion_opportunities.is_empty()
56 && self.dead_nodes.is_empty()
57 && self.redundant_computations.is_empty()
58 }
59
60 pub fn total_opportunities(&self) -> usize {
61 self.fusion_opportunities.len() + self.dead_nodes.len() + self.redundant_computations.len()
62 }
63}
64
65impl Default for OptimizationResult {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
/// Configurable analyser that inspects an `EinsumGraph` for fusion
/// opportunities, dead nodes, and redundant computations.
///
/// Configure it with the `with_*` builder methods and run it with `analyze`.
/// Derives `Debug` and `Clone` so a configuration can be logged and reused.
#[derive(Debug, Clone)]
pub struct GraphOptimizer {
    /// Whether `analyze` looks for fusable producer/consumer pairs.
    enable_fusion: bool,
    /// Whether `analyze` looks for nodes whose outputs nobody reads.
    enable_dead_node_elimination: bool,
    /// Whether `analyze` looks for duplicated computations.
    enable_redundancy_detection: bool,
    /// Minimum estimated speedup a fusion needs in order to be reported.
    min_fusion_benefit: u32,
}
78
79impl GraphOptimizer {
80 pub fn new() -> Self {
81 GraphOptimizer {
82 enable_fusion: true,
83 enable_dead_node_elimination: true,
84 enable_redundancy_detection: true,
85 min_fusion_benefit: 10, }
87 }
88
89 pub fn with_fusion(mut self, enabled: bool) -> Self {
90 self.enable_fusion = enabled;
91 self
92 }
93
94 pub fn with_dead_node_elimination(mut self, enabled: bool) -> Self {
95 self.enable_dead_node_elimination = enabled;
96 self
97 }
98
99 pub fn with_redundancy_detection(mut self, enabled: bool) -> Self {
100 self.enable_redundancy_detection = enabled;
101 self
102 }
103
104 pub fn with_min_fusion_benefit(mut self, min_benefit: u32) -> Self {
105 self.min_fusion_benefit = min_benefit;
106 self
107 }
108
109 pub fn analyze(&self, graph: &EinsumGraph) -> OptimizationResult {
111 let mut result = OptimizationResult::new();
112
113 let tensor_producers = self.build_producer_map(graph);
115 let tensor_consumers = self.build_consumer_map(graph);
116
117 if self.enable_fusion {
119 result.fusion_opportunities =
120 self.detect_fusion_opportunities(graph, &tensor_producers, &tensor_consumers);
121 }
122
123 if self.enable_dead_node_elimination {
125 result.dead_nodes = self.detect_dead_nodes(graph, &tensor_consumers);
126 }
127
128 if self.enable_redundancy_detection {
130 result.redundant_computations = self.detect_redundant_computations(graph);
131 }
132
133 result.estimated_improvement = self.estimate_improvement(&result);
135
136 result
137 }
138
139 fn build_producer_map(&self, graph: &EinsumGraph) -> HashMap<usize, usize> {
141 let mut producers = HashMap::new();
142 for (node_idx, node) in graph.nodes.iter().enumerate() {
143 for &output_idx in &node.outputs {
144 producers.insert(output_idx, node_idx);
145 }
146 }
147 producers
148 }
149
150 fn build_consumer_map(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
152 let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
153 for (node_idx, node) in graph.nodes.iter().enumerate() {
154 for &input_idx in &node.inputs {
155 consumers.entry(input_idx).or_default().push(node_idx);
156 }
157 }
158 consumers
159 }
160
161 fn detect_fusion_opportunities(
163 &self,
164 graph: &EinsumGraph,
165 tensor_producers: &HashMap<usize, usize>,
166 tensor_consumers: &HashMap<usize, Vec<usize>>,
167 ) -> Vec<FusionOpportunity> {
168 let mut opportunities = Vec::new();
169
170 for (node_idx, node) in graph.nodes.iter().enumerate() {
171 for &input_idx in &node.inputs {
173 if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
174 let is_single_use = tensor_consumers
176 .get(&input_idx)
177 .map(|consumers| consumers.len() == 1)
178 .unwrap_or(false);
179
180 if is_single_use {
181 if let Some(fusion_type) = self.can_fuse(&graph.nodes[producer_idx], node) {
182 let estimated_speedup = self.estimate_fusion_speedup(fusion_type);
183 if estimated_speedup >= self.min_fusion_benefit {
184 opportunities.push(FusionOpportunity {
185 producer_idx,
186 consumer_idx: node_idx,
187 fusion_type,
188 estimated_speedup,
189 });
190 }
191 }
192 }
193 }
194 }
195 }
196
197 opportunities
198 }
199
200 fn can_fuse(&self, producer: &EinsumNode, consumer: &EinsumNode) -> Option<FusionType> {
202 match (&producer.op, &consumer.op) {
203 (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
205 | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
206 | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
207 | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
208 Some(FusionType::ElementWise)
209 }
210
211 (OpType::Reduce { .. }, OpType::ElemUnary { .. })
213 | (OpType::Reduce { .. }, OpType::ElemBinary { .. }) => {
214 Some(FusionType::ReductionElementWise)
215 }
216
217 (OpType::Einsum { .. }, OpType::Einsum { .. }) => Some(FusionType::EinsumChain),
219
220 _ => None,
221 }
222 }
223
224 fn estimate_fusion_speedup(&self, fusion_type: FusionType) -> u32 {
226 match fusion_type {
227 FusionType::ElementWise => 40, FusionType::ReductionElementWise => 25, FusionType::MultiReduction => 30, FusionType::EinsumChain => 20, }
232 }
233
234 fn detect_dead_nodes(
236 &self,
237 graph: &EinsumGraph,
238 tensor_consumers: &HashMap<usize, Vec<usize>>,
239 ) -> Vec<usize> {
240 let mut dead_nodes = Vec::new();
241
242 for (node_idx, node) in graph.nodes.iter().enumerate() {
243 let all_outputs_unused = node.outputs.iter().all(|&output_idx| {
245 tensor_consumers
246 .get(&output_idx)
247 .map(|consumers| consumers.is_empty())
248 .unwrap_or(true)
249 });
250
251 if all_outputs_unused {
252 dead_nodes.push(node_idx);
253 }
254 }
255
256 dead_nodes
257 }
258
259 fn detect_redundant_computations(&self, graph: &EinsumGraph) -> Vec<(usize, usize)> {
261 let mut redundant_pairs = Vec::new();
262 let mut seen: HashMap<String, Vec<usize>> = HashMap::new();
263
264 for (node_idx, node) in graph.nodes.iter().enumerate() {
265 let mut signature = format!("{:?}", node.op);
267 let mut sorted_inputs = node.inputs.clone();
268 sorted_inputs.sort_unstable();
269 signature.push_str(&format!("{:?}", sorted_inputs));
270
271 if let Some(previous_nodes) = seen.get(&signature) {
273 for &prev_idx in previous_nodes {
274 redundant_pairs.push((prev_idx, node_idx));
275 }
276 }
277
278 seen.entry(signature).or_default().push(node_idx);
279 }
280
281 redundant_pairs
282 }
283
284 fn estimate_improvement(&self, result: &OptimizationResult) -> f64 {
286 let mut total_improvement = 0.0;
287
288 for fusion in &result.fusion_opportunities {
290 total_improvement += fusion.estimated_speedup as f64;
291 }
292
293 total_improvement += result.dead_nodes.len() as f64 * 5.0;
295
296 total_improvement += result.redundant_computations.len() as f64 * 10.0;
298
299 total_improvement
300 }
301}
302
303impl Default for GraphOptimizer {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
/// Selects a conflict-free subset of fusion opportunities.
///
/// Derives `Debug` and `Clone` so a planner configuration can be logged and
/// reused.
#[derive(Debug, Clone)]
pub struct FusionPlanner {
    /// Upper bound on how many fusions `plan_fusions` puts into one plan.
    max_fusion_depth: usize,
}
313
314impl FusionPlanner {
315 pub fn new() -> Self {
316 FusionPlanner {
317 max_fusion_depth: 3,
318 }
319 }
320
321 pub fn with_max_depth(mut self, depth: usize) -> Self {
322 self.max_fusion_depth = depth;
323 self
324 }
325
326 pub fn plan_fusions(&self, opportunities: &[FusionOpportunity]) -> Vec<FusionOpportunity> {
328 let mut planned = Vec::new();
329 let mut fused_nodes = HashSet::new();
330
331 let mut sorted_ops = opportunities.to_vec();
333 sorted_ops.sort_by_key(|b| Reverse(b.estimated_speedup));
334
335 for fusion in sorted_ops {
336 if fused_nodes.contains(&fusion.producer_idx)
338 || fused_nodes.contains(&fusion.consumer_idx)
339 {
340 continue;
341 }
342
343 if planned.len() >= self.max_fusion_depth {
345 break;
346 }
347
348 planned.push(fusion.clone());
349 fused_nodes.insert(fusion.producer_idx);
350 fused_nodes.insert(fusion.consumer_idx);
351 }
352
353 planned
354 }
355
356 pub fn validate_plan(&self, plan: &[FusionOpportunity]) -> bool {
358 let mut used_nodes = HashSet::new();
359
360 for fusion in plan {
361 if used_nodes.contains(&fusion.producer_idx)
362 || used_nodes.contains(&fusion.consumer_idx)
363 {
364 return false;
365 }
366 used_nodes.insert(fusion.producer_idx);
367 used_nodes.insert(fusion.consumer_idx);
368 }
369
370 true
371 }
372}
373
374impl Default for FusionPlanner {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata-free node from its tensor indices and operation.
    fn node(inputs: Vec<usize>, outputs: Vec<usize>, op: OpType) -> EinsumNode {
        EinsumNode {
            inputs,
            outputs,
            op,
            metadata: None,
        }
    }

    /// Shorthand for a unary element-wise operation with the given name.
    fn elem_unary(name: &'static str) -> OpType {
        OpType::ElemUnary { op: name.into() }
    }

    /// Shorthand for a binary element-wise operation with the given name.
    fn elem_binary(name: &'static str) -> OpType {
        OpType::ElemBinary { op: name.into() }
    }

    /// Appends each name to the graph's tensor list, in order.
    fn add_tensors(graph: &mut EinsumGraph, names: &[&str]) {
        for name in names {
            graph.tensors.push((*name).to_string());
        }
    }

    /// Shorthand for an element-wise fusion opportunity.
    fn elementwise(producer_idx: usize, consumer_idx: usize, estimated_speedup: u32) -> FusionOpportunity {
        FusionOpportunity {
            producer_idx,
            consumer_idx,
            fusion_type: FusionType::ElementWise,
            estimated_speedup,
        }
    }

    /// Chain: `x, y --einsum--> t2 --add--> t3 --mul--> t4`.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        add_tensors(&mut graph, &["x", "y", "t2", "t3", "t4"]);

        graph.nodes.push(node(
            vec![0, 1],
            vec![2],
            OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
        ));
        graph.nodes.push(node(vec![2], vec![3], elem_unary("add")));
        graph.nodes.push(node(vec![3], vec![4], elem_unary("mul")));

        graph
    }

    /// The chain graph plus a fourth node whose output `t5` nobody reads.
    fn create_graph_with_dead_node() -> EinsumGraph {
        let mut graph = create_test_graph();
        add_tensors(&mut graph, &["t5"]);
        graph.nodes.push(node(vec![0], vec![5], elem_unary("add")));
        graph
    }

    /// Two identical `add(x, y)` nodes writing to different tensors.
    fn create_graph_with_redundancy() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        add_tensors(&mut graph, &["x", "y", "t2", "t3"]);
        graph.nodes.push(node(vec![0, 1], vec![2], elem_binary("add")));
        graph.nodes.push(node(vec![0, 1], vec![3], elem_binary("add")));
        graph
    }

    #[test]
    fn test_optimizer_creation() {
        let optimizer = GraphOptimizer::new();

        assert!(optimizer.enable_fusion);
        assert!(optimizer.enable_dead_node_elimination);
        assert!(optimizer.enable_redundancy_detection);
        assert_eq!(optimizer.min_fusion_benefit, 10);
    }

    #[test]
    fn test_optimizer_builder() {
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_min_fusion_benefit(20);

        assert!(!optimizer.enable_fusion);
        assert!(!optimizer.enable_dead_node_elimination);
        assert_eq!(optimizer.min_fusion_benefit, 20);
    }

    #[test]
    fn test_producer_map() {
        let graph = create_test_graph();
        let producers = GraphOptimizer::new().build_producer_map(&graph);

        // (tensor index, index of the node that produces it)
        for &(tensor, producer) in [(2usize, 0usize), (3, 1), (4, 2)].iter() {
            assert_eq!(producers.get(&tensor), Some(&producer));
        }
    }

    #[test]
    fn test_consumer_map() {
        let graph = create_test_graph();
        let consumers = GraphOptimizer::new().build_consumer_map(&graph);

        // (tensor index, index of its single reader)
        for &(tensor, reader) in [(0usize, 0usize), (2, 1), (3, 2)].iter() {
            assert_eq!(consumers.get(&tensor), Some(&vec![reader]));
        }
    }

    #[test]
    fn test_fusion_detection() {
        let graph = create_test_graph();
        let result = GraphOptimizer::new().analyze(&graph);

        let fusion = result
            .fusion_opportunities
            .first()
            .expect("the add -> mul chain should be reported as fusable");
        assert_eq!(fusion.fusion_type, FusionType::ElementWise);
        assert!(fusion.estimated_speedup >= 10);
    }

    #[test]
    fn test_dead_node_detection() {
        let graph = create_graph_with_dead_node();
        let result = GraphOptimizer::new().analyze(&graph);

        assert!(!result.dead_nodes.is_empty());
        assert!(result.dead_nodes.contains(&3));
    }

    #[test]
    fn test_redundancy_detection() {
        let graph = create_graph_with_redundancy();
        let result = GraphOptimizer::new().analyze(&graph);

        assert!(!result.redundant_computations.is_empty());
        assert_eq!(result.redundant_computations[0], (0, 1));
    }

    #[test]
    fn test_optimization_result_empty() {
        let result = OptimizationResult::new();

        assert!(result.is_empty());
        assert_eq!(result.total_opportunities(), 0);
    }

    #[test]
    fn test_optimization_result_nonempty() {
        let mut result = OptimizationResult::new();
        result.fusion_opportunities.push(elementwise(0, 1, 40));
        result.dead_nodes.push(2);

        assert!(!result.is_empty());
        assert_eq!(result.total_opportunities(), 2);
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let producer = node(vec![0], vec![1], elem_unary("add"));
        let consumer = node(vec![1], vec![2], elem_unary("mul"));

        let fusion_type = GraphOptimizer::new().can_fuse(&producer, &consumer);
        assert_eq!(fusion_type, Some(FusionType::ElementWise));
    }

    #[test]
    fn test_fusion_planner_creation() {
        assert_eq!(FusionPlanner::new().max_fusion_depth, 3);
    }

    #[test]
    fn test_fusion_planner_with_depth() {
        let planner = FusionPlanner::new().with_max_depth(5);
        assert_eq!(planner.max_fusion_depth, 5);
    }

    #[test]
    fn test_fusion_planning() {
        let opportunities = vec![
            elementwise(0, 1, 40),
            FusionOpportunity {
                producer_idx: 2,
                consumer_idx: 3,
                fusion_type: FusionType::ReductionElementWise,
                estimated_speedup: 25,
            },
        ];

        let planner = FusionPlanner::new();
        let plan = planner.plan_fusions(&opportunities);

        assert_eq!(plan.len(), 2);
        assert!(planner.validate_plan(&plan));
    }

    #[test]
    fn test_fusion_planning_with_conflicts() {
        // Node 1 takes part in both candidates, so only the better one survives.
        let opportunities = vec![elementwise(0, 1, 40), elementwise(1, 2, 35)];

        let plan = FusionPlanner::new().plan_fusions(&opportunities);

        assert_eq!(plan.len(), 1);
        assert_eq!(plan[0].producer_idx, 0);
    }

    #[test]
    fn test_estimate_improvement() {
        let mut result = OptimizationResult::new();
        result.fusion_opportunities.push(elementwise(0, 1, 40));
        result.dead_nodes.push(2);
        result.redundant_computations.push((3, 4));

        let improvement = GraphOptimizer::new().estimate_improvement(&result);

        assert!(improvement > 0.0);
        // One fusion (40) + one dead node (5) + one redundant pair (10).
        assert_eq!(improvement, 40.0 + 5.0 + 10.0);
    }

    #[test]
    fn test_disabled_optimizations() {
        let graph = create_graph_with_dead_node();
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_redundancy_detection(false);

        let result = optimizer.analyze(&graph);

        assert!(result.fusion_opportunities.is_empty());
        assert!(result.dead_nodes.is_empty());
        assert!(result.redundant_computations.is_empty());
    }
}