1use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use tensorlogic_ir::{EinsumGraph, OpType};
14use thiserror::Error;
15
/// Position of a node within an `EinsumGraph`'s `nodes` vector.
///
/// A newtype rather than a bare `usize`, so node positions are not mixed up
/// with the other integer indices stored in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
19
20#[derive(Error, Debug, Clone, PartialEq)]
22pub enum FusionError {
23 #[error("Fusion would create a cycle in the graph")]
24 WouldCreateCycle,
25
26 #[error("Incompatible operations for fusion: {0:?} and {1:?}")]
27 IncompatibleOps(OpType, OpType),
28
29 #[error("Fusion exceeds resource limits: {0}")]
30 ResourceLimitExceeded(String),
31
32 #[error("Invalid fusion pattern")]
33 InvalidPattern,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub enum FusionPattern {
39 MatMulBias,
41 MatMulActivation,
43 BiasActivation,
45 BatchNormReLU,
47 ConvBNReLU,
49 ElementwiseChain,
51 ReduceElementwise,
53 ParallelReductions,
55 BroadcastElementwise,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum FusionStrategy {
62 Conservative,
64 Aggressive,
66 Balanced,
68 MemoryAware,
70}
71
72#[derive(Debug, Clone, PartialEq)]
74pub struct FusionCandidate {
75 pub nodes: Vec<NodeId>,
77 pub pattern: FusionPattern,
79 pub benefit_score: f64,
81 pub memory_savings: usize,
83 pub compute_savings: f64,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FusionConfig {
90 pub strategy: FusionStrategy,
92 pub max_fusion_size: usize,
94 pub enable_patterns: bool,
96 pub enable_vertical: bool,
98 pub enable_horizontal: bool,
100 pub enable_loop_fusion: bool,
102 pub memory_bandwidth_threshold: Option<f64>,
104 pub min_benefit_score: f64,
106}
107
108impl Default for FusionConfig {
109 fn default() -> Self {
110 Self {
111 strategy: FusionStrategy::Balanced,
112 max_fusion_size: 8,
113 enable_patterns: true,
114 enable_vertical: true,
115 enable_horizontal: true,
116 enable_loop_fusion: true,
117 memory_bandwidth_threshold: None,
118 min_benefit_score: 0.1,
119 }
120 }
121}
122
123impl FusionConfig {
124 pub fn aggressive() -> Self {
126 Self {
127 strategy: FusionStrategy::Aggressive,
128 max_fusion_size: 16,
129 min_benefit_score: 0.0,
130 ..Default::default()
131 }
132 }
133
134 pub fn conservative() -> Self {
136 Self {
137 strategy: FusionStrategy::Conservative,
138 max_fusion_size: 4,
139 enable_horizontal: false,
140 enable_loop_fusion: false,
141 min_benefit_score: 0.3,
142 ..Default::default()
143 }
144 }
145
146 pub fn memory_aware() -> Self {
148 Self {
149 strategy: FusionStrategy::MemoryAware,
150 memory_bandwidth_threshold: Some(100e9), ..Default::default()
152 }
153 }
154}
155
/// Simple analytic cost model used to score fusion candidates.
///
/// Costs are unitless sums of three terms:
/// - a per-element memory term,
/// - a per-kernel launch term,
/// - for fused kernels, a per-op compute term.
#[derive(Debug, Clone)]
pub struct FusionCostModel {
    /// Cost charged per element of data moved to or from memory.
    pub memory_access_cost: f64,
    /// Extra cost charged per operation executed inside a fused kernel.
    pub compute_cost: f64,
    /// Fixed cost charged for every kernel launch.
    pub kernel_launch_cost: f64,
    /// Memory bandwidth (presumably bytes/s); not read by the cost functions below.
    pub memory_bandwidth: f64,
}

impl Default for FusionCostModel {
    fn default() -> Self {
        Self {
            memory_access_cost: 1.0,
            compute_cost: 0.1,
            kernel_launch_cost: 10.0,
            memory_bandwidth: 100e9,
        }
    }
}

impl FusionCostModel {
    /// Cost of running `num_ops` operations as separate kernels.
    ///
    /// Each op pays `memory_access_cost * data_size` for its own memory
    /// traffic, plus one `kernel_launch_cost`.
    pub fn cost_separate(&self, num_ops: usize, data_size: usize) -> f64 {
        let ops = num_ops as f64;
        let memory_cost = self.memory_access_cost * data_size as f64 * ops;
        memory_cost + self.kernel_launch_cost * ops
    }

    /// Cost of running `num_ops` operations as one fused kernel.
    ///
    /// The data crosses memory twice; presumably that is one read and one
    /// write, hence the `2.0`. A single launch is paid, and each fused op
    /// adds `compute_cost`.
    pub fn cost_fused(&self, num_ops: usize, data_size: usize) -> f64 {
        let memory_cost = self.memory_access_cost * data_size as f64 * 2.0;
        let compute_overhead = self.compute_cost * num_ops as f64;
        memory_cost + self.kernel_launch_cost + compute_overhead
    }

    /// Relative saving from fusion: `(separate - fused) / separate`.
    ///
    /// Positive means fusion is cheaper. Negative means it is more expensive;
    /// with the default parameters this happens for a single op, which still
    /// pays the doubled memory term.
    ///
    /// Returns `0.0` when the separate cost is not positive (e.g.
    /// `num_ops == 0`). There the ratio would be NaN, and NaN scores would
    /// silently fail every threshold comparison.
    pub fn fusion_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
        let separate_cost = self.cost_separate(num_ops, data_size);
        if separate_cost <= 0.0 {
            return 0.0;
        }
        let fused_cost = self.cost_fused(num_ops, data_size);
        (separate_cost - fused_cost) / separate_cost
    }
}
204
205pub struct FusionOptimizer {
207 config: FusionConfig,
208 cost_model: FusionCostModel,
209 candidates: Vec<FusionCandidate>,
210}
211
212impl FusionOptimizer {
213 pub fn new(config: FusionConfig) -> Self {
215 Self {
216 config,
217 cost_model: FusionCostModel::default(),
218 candidates: Vec::new(),
219 }
220 }
221
222 pub fn with_cost_model(config: FusionConfig, cost_model: FusionCostModel) -> Self {
224 Self {
225 config,
226 cost_model,
227 candidates: Vec::new(),
228 }
229 }
230
231 pub fn analyze(&mut self, graph: &EinsumGraph) -> Vec<FusionCandidate> {
233 self.candidates.clear();
234
235 if self.config.enable_patterns {
236 self.find_pattern_fusions(graph);
237 }
238
239 if self.config.enable_vertical {
240 self.find_vertical_fusions(graph);
241 }
242
243 if self.config.enable_horizontal {
244 self.find_horizontal_fusions(graph);
245 }
246
247 self.candidates.sort_by(|a, b| {
249 b.benefit_score
250 .partial_cmp(&a.benefit_score)
251 .unwrap_or(std::cmp::Ordering::Equal)
252 });
253
254 self.candidates.clone()
255 }
256
257 fn find_pattern_fusions(&mut self, graph: &EinsumGraph) {
259 for node_id in 0..graph.nodes.len() {
261 let node_id = NodeId(node_id);
262 let node = &graph.nodes[node_id.0];
263
264 if matches!(node.op, OpType::Einsum { .. }) {
266 let consumers = self.find_consumers(graph, node_id);
268 for consumer in consumers {
269 let consumer_node = &graph.nodes[consumer.0];
270 if matches!(
271 consumer_node.op,
272 OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
273 ) {
274 let benefit = self.estimate_pattern_benefit(2, 1024); if benefit >= self.config.min_benefit_score {
277 self.candidates.push(FusionCandidate {
278 nodes: vec![node_id, consumer],
279 pattern: FusionPattern::MatMulActivation,
280 benefit_score: benefit,
281 memory_savings: 1024 * 4, compute_savings: 0.0,
283 });
284 }
285 }
286 }
287 }
288 }
289 }
290
291 fn find_vertical_fusions(&mut self, graph: &EinsumGraph) {
293 for node_id in 0..graph.nodes.len() {
294 let node_id = NodeId(node_id);
295 let consumers = self.find_consumers(graph, node_id);
296
297 if consumers.len() == 1 {
299 let consumer = consumers[0];
300 if self.can_fuse_vertically(graph, node_id, consumer) {
301 let benefit = self.cost_model.fusion_benefit(2, 1024);
302
303 if benefit >= self.config.min_benefit_score {
304 self.candidates.push(FusionCandidate {
305 nodes: vec![node_id, consumer],
306 pattern: FusionPattern::ElementwiseChain,
307 benefit_score: benefit,
308 memory_savings: 1024 * 4,
309 compute_savings: 0.0,
310 });
311 }
312 }
313 }
314 }
315 }
316
317 fn find_horizontal_fusions(&mut self, graph: &EinsumGraph) {
319 let _independent_groups: Vec<Vec<NodeId>> = Vec::new();
320
321 let mut depth_groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
323
324 for node_id in 0..graph.nodes.len() {
325 let depth = self.compute_depth(graph, NodeId(node_id));
326 depth_groups.entry(depth).or_default().push(NodeId(node_id));
327 }
328
329 for (_, nodes) in depth_groups {
331 if nodes.len() >= 2 {
332 for i in 0..nodes.len() {
334 for j in i + 1..nodes.len() {
335 if self.are_independent(graph, nodes[i], nodes[j])
336 && self.have_similar_ops(graph, nodes[i], nodes[j])
337 {
338 let benefit = self.cost_model.fusion_benefit(2, 512);
339
340 if benefit >= self.config.min_benefit_score {
341 self.candidates.push(FusionCandidate {
342 nodes: vec![nodes[i], nodes[j]],
343 pattern: FusionPattern::ParallelReductions,
344 benefit_score: benefit * 0.8, memory_savings: 512 * 4,
346 compute_savings: 0.0,
347 });
348 }
349 }
350 }
351 }
352 }
353 }
354 }
355
356 fn can_fuse_vertically(
358 &self,
359 _graph: &EinsumGraph,
360 _producer: NodeId,
361 _consumer: NodeId,
362 ) -> bool {
363 true }
370
371 fn are_independent(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
373 let a_deps = self.get_all_dependencies(graph, a);
375 let b_deps = self.get_all_dependencies(graph, b);
376
377 !a_deps.contains(&b) && !b_deps.contains(&a)
378 }
379
380 fn have_similar_ops(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
382 let op_a = &graph.nodes[a.0].op;
383 let op_b = &graph.nodes[b.0].op;
384
385 std::mem::discriminant(op_a) == std::mem::discriminant(op_b)
386 }
387
388 fn find_consumers(&self, graph: &EinsumGraph, producer: NodeId) -> Vec<NodeId> {
390 let mut consumers = Vec::new();
391
392 for (i, node) in graph.nodes.iter().enumerate() {
393 if node.inputs.iter().any(|&n| NodeId(n) == producer) {
394 consumers.push(NodeId(i));
395 }
396 }
397
398 consumers
399 }
400
401 fn get_all_dependencies(&self, graph: &EinsumGraph, node_id: NodeId) -> HashSet<NodeId> {
403 let mut deps = HashSet::new();
404 let mut to_visit = vec![node_id];
405
406 while let Some(current) = to_visit.pop() {
407 if deps.contains(¤t) {
408 continue;
409 }
410 deps.insert(current);
411
412 let node = &graph.nodes[current.0];
413 for &input in &node.inputs {
414 to_visit.push(NodeId(input));
415 }
416 }
417
418 deps
419 }
420
421 #[allow(clippy::only_used_in_recursion)]
423 fn compute_depth(&self, graph: &EinsumGraph, node_id: NodeId) -> usize {
424 let node = &graph.nodes[node_id.0];
425
426 if node.inputs.is_empty() {
427 0
428 } else {
429 1 + node
430 .inputs
431 .iter()
432 .map(|&input| self.compute_depth(graph, NodeId(input)))
433 .max()
434 .unwrap_or(0)
435 }
436 }
437
438 fn estimate_pattern_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
440 match self.config.strategy {
441 FusionStrategy::Aggressive => self.cost_model.fusion_benefit(num_ops, data_size) * 1.2,
442 FusionStrategy::Conservative => {
443 self.cost_model.fusion_benefit(num_ops, data_size) * 0.8
444 }
445 FusionStrategy::Balanced => self.cost_model.fusion_benefit(num_ops, data_size),
446 FusionStrategy::MemoryAware => {
447 let base_benefit = self.cost_model.fusion_benefit(num_ops, data_size);
448 base_benefit * 1.5
450 }
451 }
452 }
453
454 pub fn apply_fusions(
456 &self,
457 graph: &EinsumGraph,
458 _candidates: &[FusionCandidate],
459 ) -> Result<EinsumGraph, FusionError> {
460 Ok(graph.clone())
463 }
464
465 pub fn stats(&self) -> FusionStats {
467 let total_candidates = self.candidates.len();
468 let total_memory_savings: usize = self.candidates.iter().map(|c| c.memory_savings).sum();
469 let avg_benefit_score = if total_candidates > 0 {
470 self.candidates.iter().map(|c| c.benefit_score).sum::<f64>() / total_candidates as f64
471 } else {
472 0.0
473 };
474
475 let mut pattern_counts = HashMap::new();
476 for candidate in &self.candidates {
477 *pattern_counts.entry(candidate.pattern).or_insert(0) += 1;
478 }
479
480 FusionStats {
481 total_candidates,
482 total_memory_savings,
483 avg_benefit_score,
484 pattern_distribution: pattern_counts,
485 }
486 }
487}
488
489#[derive(Debug, Clone, Serialize, Deserialize)]
491pub struct FusionStats {
492 pub total_candidates: usize,
494 pub total_memory_savings: usize,
496 pub avg_benefit_score: f64,
498 pub pattern_distribution: HashMap<FusionPattern, usize>,
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Two-node graph.
    ///
    /// Node 0 is an `Einsum` with no inputs and output `0`. Node 1 is a `relu`
    /// `ElemUnary` with input `0` and output `1`.
    fn create_test_graph() -> EinsumGraph {
        let matmul = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        };
        let relu = EinsumNode {
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            inputs: vec![0],
            outputs: vec![1],
            metadata: Default::default(),
        };

        let mut graph = EinsumGraph::new();
        graph.nodes.push(matmul);
        graph.nodes.push(relu);
        graph
    }

    /// Default configuration with the benefit threshold removed.
    fn permissive_config() -> FusionConfig {
        FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        }
    }

    #[test]
    fn test_fusion_config() {
        let aggressive = FusionConfig::aggressive();
        assert_eq!(aggressive.strategy, FusionStrategy::Aggressive);
        assert!(aggressive.max_fusion_size >= FusionConfig::default().max_fusion_size);

        assert_eq!(
            FusionConfig::conservative().strategy,
            FusionStrategy::Conservative
        );
    }

    #[test]
    fn test_cost_model() {
        let model = FusionCostModel::default();

        let three_ops = model.fusion_benefit(3, 1024);
        assert!(three_ops > 0.0 && three_ops < 1.0);

        // More ops amortise the single fused launch further.
        assert!(model.fusion_benefit(5, 1024) > three_ops);
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(optimizer.candidates.is_empty());
    }

    #[test]
    fn test_fusion_analysis() {
        let mut optimizer = FusionOptimizer::new(permissive_config());
        let candidates = optimizer.analyze(&create_test_graph());
        assert!(!candidates.is_empty());
    }

    #[test]
    fn test_consumer_finding() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        assert_eq!(optimizer.find_consumers(&graph, NodeId(0)), vec![NodeId(1)]);
    }

    #[test]
    fn test_depth_computation() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        let depths: Vec<usize> = (0..2)
            .map(|index| optimizer.compute_depth(&graph, NodeId(index)))
            .collect();
        assert_eq!(depths, vec![0, 1]);
    }

    #[test]
    fn test_independence_check() {
        let mut graph = create_test_graph();
        // Node 2 has no inputs, so nothing links it to nodes 0 and 1.
        graph.nodes.push(EinsumNode {
            op: OpType::ElemUnary {
                op: "tanh".to_string(),
            },
            inputs: vec![],
            outputs: vec![2],
            metadata: Default::default(),
        });
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        // Node 1 reads node 0's output.
        assert!(!optimizer.are_independent(&graph, NodeId(0), NodeId(1)));
        assert!(optimizer.are_independent(&graph, NodeId(0), NodeId(2)));
    }

    #[test]
    fn test_fusion_stats() {
        let mut optimizer = FusionOptimizer::new(permissive_config());
        optimizer.analyze(&create_test_graph());

        let stats = optimizer.stats();
        assert!(stats.total_candidates > 0);
        assert!(stats.avg_benefit_score >= 0.0);
    }

    #[test]
    fn test_similar_ops_check() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        // A node matches itself; Einsum and ElemUnary are different variants.
        assert!(optimizer.have_similar_ops(&graph, NodeId(0), NodeId(0)));
        assert!(!optimizer.have_similar_ops(&graph, NodeId(0), NodeId(1)));
    }
}