1use std::collections::{HashMap, HashSet};
10use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct FusionOpportunity {
15 pub producer_idx: usize,
16 pub consumer_idx: usize,
17 pub fusion_type: FusionType,
18 pub estimated_speedup: u32, }
20
/// Categories of fusion the optimizer distinguishes between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionType {
    /// An element-wise op (unary or binary) feeding another element-wise op.
    ElementWise,
    /// A reduction feeding an element-wise op (unary or binary).
    ReductionElementWise,
    /// Fusion involving several reductions.
    // NOTE(review): `GraphOptimizer::can_fuse` never returns this variant; it only
    // has a speedup estimate. Confirm where it is meant to be produced.
    MultiReduction,
    /// An einsum feeding another einsum.
    EinsumChain,
}
33
34#[derive(Debug, Clone)]
36pub struct OptimizationResult {
37 pub fusion_opportunities: Vec<FusionOpportunity>,
38 pub dead_nodes: Vec<usize>,
39 pub redundant_computations: Vec<(usize, usize)>, pub estimated_improvement: f64, }
42
43impl OptimizationResult {
44 pub fn new() -> Self {
45 OptimizationResult {
46 fusion_opportunities: Vec::new(),
47 dead_nodes: Vec::new(),
48 redundant_computations: Vec::new(),
49 estimated_improvement: 0.0,
50 }
51 }
52
53 pub fn is_empty(&self) -> bool {
54 self.fusion_opportunities.is_empty()
55 && self.dead_nodes.is_empty()
56 && self.redundant_computations.is_empty()
57 }
58
59 pub fn total_opportunities(&self) -> usize {
60 self.fusion_opportunities.len() + self.dead_nodes.len() + self.redundant_computations.len()
61 }
62}
63
64impl Default for OptimizationResult {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
/// Configurable analyzer that looks for optimization opportunities in an einsum graph.
pub struct GraphOptimizer {
    /// Whether `analyze` searches for fusable producer/consumer pairs.
    enable_fusion: bool,
    /// Whether `analyze` searches for nodes whose outputs are never consumed.
    enable_dead_node_elimination: bool,
    /// Whether `analyze` searches for duplicated computations.
    enable_redundancy_detection: bool,
    /// Fusion candidates whose estimated speedup is below this value are discarded.
    min_fusion_benefit: u32,
}
77
78impl GraphOptimizer {
79 pub fn new() -> Self {
80 GraphOptimizer {
81 enable_fusion: true,
82 enable_dead_node_elimination: true,
83 enable_redundancy_detection: true,
84 min_fusion_benefit: 10, }
86 }
87
88 pub fn with_fusion(mut self, enabled: bool) -> Self {
89 self.enable_fusion = enabled;
90 self
91 }
92
93 pub fn with_dead_node_elimination(mut self, enabled: bool) -> Self {
94 self.enable_dead_node_elimination = enabled;
95 self
96 }
97
98 pub fn with_redundancy_detection(mut self, enabled: bool) -> Self {
99 self.enable_redundancy_detection = enabled;
100 self
101 }
102
103 pub fn with_min_fusion_benefit(mut self, min_benefit: u32) -> Self {
104 self.min_fusion_benefit = min_benefit;
105 self
106 }
107
108 pub fn analyze(&self, graph: &EinsumGraph) -> OptimizationResult {
110 let mut result = OptimizationResult::new();
111
112 let tensor_producers = self.build_producer_map(graph);
114 let tensor_consumers = self.build_consumer_map(graph);
115
116 if self.enable_fusion {
118 result.fusion_opportunities =
119 self.detect_fusion_opportunities(graph, &tensor_producers, &tensor_consumers);
120 }
121
122 if self.enable_dead_node_elimination {
124 result.dead_nodes = self.detect_dead_nodes(graph, &tensor_consumers);
125 }
126
127 if self.enable_redundancy_detection {
129 result.redundant_computations = self.detect_redundant_computations(graph);
130 }
131
132 result.estimated_improvement = self.estimate_improvement(&result);
134
135 result
136 }
137
138 fn build_producer_map(&self, graph: &EinsumGraph) -> HashMap<usize, usize> {
140 let mut producers = HashMap::new();
141 for (node_idx, node) in graph.nodes.iter().enumerate() {
142 for &output_idx in &node.outputs {
143 producers.insert(output_idx, node_idx);
144 }
145 }
146 producers
147 }
148
149 fn build_consumer_map(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
151 let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
152 for (node_idx, node) in graph.nodes.iter().enumerate() {
153 for &input_idx in &node.inputs {
154 consumers.entry(input_idx).or_default().push(node_idx);
155 }
156 }
157 consumers
158 }
159
160 fn detect_fusion_opportunities(
162 &self,
163 graph: &EinsumGraph,
164 tensor_producers: &HashMap<usize, usize>,
165 tensor_consumers: &HashMap<usize, Vec<usize>>,
166 ) -> Vec<FusionOpportunity> {
167 let mut opportunities = Vec::new();
168
169 for (node_idx, node) in graph.nodes.iter().enumerate() {
170 for &input_idx in &node.inputs {
172 if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
173 let is_single_use = tensor_consumers
175 .get(&input_idx)
176 .map(|consumers| consumers.len() == 1)
177 .unwrap_or(false);
178
179 if is_single_use {
180 if let Some(fusion_type) = self.can_fuse(&graph.nodes[producer_idx], node) {
181 let estimated_speedup = self.estimate_fusion_speedup(fusion_type);
182 if estimated_speedup >= self.min_fusion_benefit {
183 opportunities.push(FusionOpportunity {
184 producer_idx,
185 consumer_idx: node_idx,
186 fusion_type,
187 estimated_speedup,
188 });
189 }
190 }
191 }
192 }
193 }
194 }
195
196 opportunities
197 }
198
199 fn can_fuse(&self, producer: &EinsumNode, consumer: &EinsumNode) -> Option<FusionType> {
201 match (&producer.op, &consumer.op) {
202 (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
204 | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
205 | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
206 | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
207 Some(FusionType::ElementWise)
208 }
209
210 (OpType::Reduce { .. }, OpType::ElemUnary { .. })
212 | (OpType::Reduce { .. }, OpType::ElemBinary { .. }) => {
213 Some(FusionType::ReductionElementWise)
214 }
215
216 (OpType::Einsum { .. }, OpType::Einsum { .. }) => Some(FusionType::EinsumChain),
218
219 _ => None,
220 }
221 }
222
223 fn estimate_fusion_speedup(&self, fusion_type: FusionType) -> u32 {
225 match fusion_type {
226 FusionType::ElementWise => 40, FusionType::ReductionElementWise => 25, FusionType::MultiReduction => 30, FusionType::EinsumChain => 20, }
231 }
232
233 fn detect_dead_nodes(
235 &self,
236 graph: &EinsumGraph,
237 tensor_consumers: &HashMap<usize, Vec<usize>>,
238 ) -> Vec<usize> {
239 let mut dead_nodes = Vec::new();
240
241 for (node_idx, node) in graph.nodes.iter().enumerate() {
242 let all_outputs_unused = node.outputs.iter().all(|&output_idx| {
244 tensor_consumers
245 .get(&output_idx)
246 .map(|consumers| consumers.is_empty())
247 .unwrap_or(true)
248 });
249
250 if all_outputs_unused {
251 dead_nodes.push(node_idx);
252 }
253 }
254
255 dead_nodes
256 }
257
258 fn detect_redundant_computations(&self, graph: &EinsumGraph) -> Vec<(usize, usize)> {
260 let mut redundant_pairs = Vec::new();
261 let mut seen: HashMap<String, Vec<usize>> = HashMap::new();
262
263 for (node_idx, node) in graph.nodes.iter().enumerate() {
264 let mut signature = format!("{:?}", node.op);
266 let mut sorted_inputs = node.inputs.clone();
267 sorted_inputs.sort_unstable();
268 signature.push_str(&format!("{:?}", sorted_inputs));
269
270 if let Some(previous_nodes) = seen.get(&signature) {
272 for &prev_idx in previous_nodes {
273 redundant_pairs.push((prev_idx, node_idx));
274 }
275 }
276
277 seen.entry(signature).or_default().push(node_idx);
278 }
279
280 redundant_pairs
281 }
282
283 fn estimate_improvement(&self, result: &OptimizationResult) -> f64 {
285 let mut total_improvement = 0.0;
286
287 for fusion in &result.fusion_opportunities {
289 total_improvement += fusion.estimated_speedup as f64;
290 }
291
292 total_improvement += result.dead_nodes.len() as f64 * 5.0;
294
295 total_improvement += result.redundant_computations.len() as f64 * 10.0;
297
298 total_improvement
299 }
300}
301
302impl Default for GraphOptimizer {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
/// Selects a conflict-free subset of fusion opportunities to apply.
pub struct FusionPlanner {
    /// Upper bound on the number of fusions `plan_fusions` will select.
    max_fusion_depth: usize,
}
312
313impl FusionPlanner {
314 pub fn new() -> Self {
315 FusionPlanner {
316 max_fusion_depth: 3,
317 }
318 }
319
320 pub fn with_max_depth(mut self, depth: usize) -> Self {
321 self.max_fusion_depth = depth;
322 self
323 }
324
325 pub fn plan_fusions(&self, opportunities: &[FusionOpportunity]) -> Vec<FusionOpportunity> {
327 let mut planned = Vec::new();
328 let mut fused_nodes = HashSet::new();
329
330 let mut sorted_ops = opportunities.to_vec();
332 sorted_ops.sort_by(|a, b| b.estimated_speedup.cmp(&a.estimated_speedup));
333
334 for fusion in sorted_ops {
335 if fused_nodes.contains(&fusion.producer_idx)
337 || fused_nodes.contains(&fusion.consumer_idx)
338 {
339 continue;
340 }
341
342 if planned.len() >= self.max_fusion_depth {
344 break;
345 }
346
347 planned.push(fusion.clone());
348 fused_nodes.insert(fusion.producer_idx);
349 fused_nodes.insert(fusion.consumer_idx);
350 }
351
352 planned
353 }
354
355 pub fn validate_plan(&self, plan: &[FusionOpportunity]) -> bool {
357 let mut used_nodes = HashSet::new();
358
359 for fusion in plan {
360 if used_nodes.contains(&fusion.producer_idx)
361 || used_nodes.contains(&fusion.consumer_idx)
362 {
363 return false;
364 }
365 used_nodes.insert(fusion.producer_idx);
366 used_nodes.insert(fusion.consumer_idx);
367 }
368
369 true
370 }
371}
372
373impl Default for FusionPlanner {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a tensor called `name` and appends the single-output node that writes `output`.
    fn add_node(
        graph: &mut EinsumGraph,
        name: &str,
        inputs: Vec<usize>,
        output: usize,
        op: OpType,
    ) {
        graph.tensors.push(name.to_string());
        graph.nodes.push(EinsumNode {
            inputs,
            outputs: vec![output],
            op,
            metadata: None,
        });
    }

    /// Builds an element-wise unary op with the given name.
    fn unary(name: &'static str) -> OpType {
        OpType::ElemUnary { op: name.into() }
    }

    /// Builds an element-wise binary op with the given name.
    fn binary(name: &'static str) -> OpType {
        OpType::ElemBinary { op: name.into() }
    }

    /// Shorthand for an element-wise fusion opportunity between two nodes.
    fn elementwise(producer_idx: usize, consumer_idx: usize, speedup: u32) -> FusionOpportunity {
        FusionOpportunity {
            producer_idx,
            consumer_idx,
            fusion_type: FusionType::ElementWise,
            estimated_speedup: speedup,
        }
    }

    /// Linear chain: einsum(x, y) -> t2, unary(t2) -> t3, unary(t3) -> t4.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        // Tensors 0 and 1 have no producing node.
        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());

        add_node(
            &mut graph,
            "t2",
            vec![0, 1],
            2,
            OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
        );
        add_node(&mut graph, "t3", vec![2], 3, unary("add"));
        add_node(&mut graph, "t4", vec![3], 4, unary("mul"));

        graph
    }

    /// The linear chain plus a fourth node whose output `t5` nobody reads.
    fn create_graph_with_dead_node() -> EinsumGraph {
        let mut graph = create_test_graph();
        add_node(&mut graph, "t5", vec![0], 5, unary("add"));
        graph
    }

    /// Two nodes applying the same binary op to the same inputs.
    fn create_graph_with_redundancy() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());

        add_node(&mut graph, "t2", vec![0, 1], 2, binary("add"));
        add_node(&mut graph, "t3", vec![0, 1], 3, binary("add"));

        graph
    }

    #[test]
    fn test_optimizer_creation() {
        let optimizer = GraphOptimizer::new();

        assert!(optimizer.enable_fusion);
        assert!(optimizer.enable_dead_node_elimination);
        assert!(optimizer.enable_redundancy_detection);
        assert_eq!(optimizer.min_fusion_benefit, 10);
    }

    #[test]
    fn test_optimizer_builder() {
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_min_fusion_benefit(20);

        assert!(!optimizer.enable_fusion);
        assert!(!optimizer.enable_dead_node_elimination);
        assert_eq!(optimizer.min_fusion_benefit, 20);
    }

    #[test]
    fn test_producer_map() {
        let graph = create_test_graph();
        let producers = GraphOptimizer::new().build_producer_map(&graph);

        // tensor index -> producing node index
        assert_eq!(producers[&2], 0);
        assert_eq!(producers[&3], 1);
        assert_eq!(producers[&4], 2);
    }

    #[test]
    fn test_consumer_map() {
        let graph = create_test_graph();
        let consumers = GraphOptimizer::new().build_consumer_map(&graph);

        // tensor index -> consuming node indices
        assert_eq!(consumers[&0], vec![0]);
        assert_eq!(consumers[&2], vec![1]);
        assert_eq!(consumers[&3], vec![2]);
    }

    #[test]
    fn test_fusion_detection() {
        let result = GraphOptimizer::new().analyze(&create_test_graph());

        assert!(!result.fusion_opportunities.is_empty());

        let first = &result.fusion_opportunities[0];
        assert_eq!(first.fusion_type, FusionType::ElementWise);
        assert!(first.estimated_speedup >= 10);
    }

    #[test]
    fn test_dead_node_detection() {
        let result = GraphOptimizer::new().analyze(&create_graph_with_dead_node());

        assert!(!result.dead_nodes.is_empty());
        // Node 3 is the extra node whose output is never consumed.
        assert!(result.dead_nodes.contains(&3));
    }

    #[test]
    fn test_redundancy_detection() {
        let result = GraphOptimizer::new().analyze(&create_graph_with_redundancy());

        assert!(!result.redundant_computations.is_empty());
        assert_eq!(result.redundant_computations[0], (0, 1));
    }

    #[test]
    fn test_optimization_result_empty() {
        let result = OptimizationResult::new();

        assert!(result.is_empty());
        assert_eq!(result.total_opportunities(), 0);
    }

    #[test]
    fn test_optimization_result_nonempty() {
        let mut result = OptimizationResult::new();
        result.fusion_opportunities.push(elementwise(0, 1, 40));
        result.dead_nodes.push(2);

        assert!(!result.is_empty());
        assert_eq!(result.total_opportunities(), 2);
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let producer = EinsumNode {
            inputs: vec![0],
            outputs: vec![1],
            op: unary("add"),
            metadata: None,
        };
        let consumer = EinsumNode {
            inputs: vec![1],
            outputs: vec![2],
            op: unary("mul"),
            metadata: None,
        };

        let fusion_type = GraphOptimizer::new().can_fuse(&producer, &consumer);
        assert_eq!(fusion_type, Some(FusionType::ElementWise));
    }

    #[test]
    fn test_fusion_planner_creation() {
        assert_eq!(FusionPlanner::new().max_fusion_depth, 3);
    }

    #[test]
    fn test_fusion_planner_with_depth() {
        assert_eq!(FusionPlanner::new().with_max_depth(5).max_fusion_depth, 5);
    }

    #[test]
    fn test_fusion_planning() {
        let opportunities = vec![
            elementwise(0, 1, 40),
            FusionOpportunity {
                producer_idx: 2,
                consumer_idx: 3,
                fusion_type: FusionType::ReductionElementWise,
                estimated_speedup: 25,
            },
        ];

        let planner = FusionPlanner::new();
        let plan = planner.plan_fusions(&opportunities);

        assert_eq!(plan.len(), 2);
        assert!(planner.validate_plan(&plan));
    }

    #[test]
    fn test_fusion_planning_with_conflicts() {
        // Node 1 takes part in both candidates, so only the better one survives.
        let opportunities = vec![elementwise(0, 1, 40), elementwise(1, 2, 35)];

        let plan = FusionPlanner::new().plan_fusions(&opportunities);

        assert_eq!(plan.len(), 1);
        assert_eq!(plan[0].producer_idx, 0);
    }

    #[test]
    fn test_estimate_improvement() {
        let mut result = OptimizationResult::new();
        result.fusion_opportunities.push(elementwise(0, 1, 40));
        result.dead_nodes.push(2);
        result.redundant_computations.push((3, 4));

        let improvement = GraphOptimizer::new().estimate_improvement(&result);

        assert!(improvement > 0.0);
        // fusion speedup + one dead node + one redundant pair
        assert_eq!(improvement, 40.0 + 5.0 + 10.0);
    }

    #[test]
    fn test_disabled_optimizations() {
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_redundancy_detection(false);

        let result = optimizer.analyze(&create_graph_with_dead_node());

        assert!(result.fusion_opportunities.is_empty());
        assert!(result.dead_nodes.is_empty());
        assert!(result.redundant_computations.is_empty());
    }
}