1use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use tensorlogic_ir::{EinsumGraph, OpType};
14use thiserror::Error;
15
/// Identifies a node of an [`EinsumGraph`] by its position in `graph.nodes`.
///
/// The wrapped index is public, so callers can use `id.0` to index the node
/// list directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
19
20#[derive(Error, Debug, Clone, PartialEq)]
22pub enum FusionError {
23 #[error("Fusion would create a cycle in the graph")]
24 WouldCreateCycle,
25
26 #[error("Incompatible operations for fusion: {0:?} and {1:?}")]
27 IncompatibleOps(OpType, OpType),
28
29 #[error("Fusion exceeds resource limits: {0}")]
30 ResourceLimitExceeded(String),
31
32 #[error("Invalid fusion pattern")]
33 InvalidPattern,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub enum FusionPattern {
39 MatMulBias,
41 MatMulActivation,
43 BiasActivation,
45 BatchNormReLU,
47 ConvBNReLU,
49 ElementwiseChain,
51 ReduceElementwise,
53 ParallelReductions,
55 BroadcastElementwise,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum FusionStrategy {
62 Conservative,
64 Aggressive,
66 Balanced,
68 MemoryAware,
70}
71
72#[derive(Debug, Clone, PartialEq)]
74pub struct FusionCandidate {
75 pub nodes: Vec<NodeId>,
77 pub pattern: FusionPattern,
79 pub benefit_score: f64,
81 pub memory_savings: usize,
83 pub compute_savings: f64,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FusionConfig {
90 pub strategy: FusionStrategy,
92 pub max_fusion_size: usize,
94 pub enable_patterns: bool,
96 pub enable_vertical: bool,
98 pub enable_horizontal: bool,
100 pub enable_loop_fusion: bool,
102 pub memory_bandwidth_threshold: Option<f64>,
104 pub min_benefit_score: f64,
106}
107
108impl Default for FusionConfig {
109 fn default() -> Self {
110 Self {
111 strategy: FusionStrategy::Balanced,
112 max_fusion_size: 8,
113 enable_patterns: true,
114 enable_vertical: true,
115 enable_horizontal: true,
116 enable_loop_fusion: true,
117 memory_bandwidth_threshold: None,
118 min_benefit_score: 0.1,
119 }
120 }
121}
122
123impl FusionConfig {
124 pub fn aggressive() -> Self {
126 Self {
127 strategy: FusionStrategy::Aggressive,
128 max_fusion_size: 16,
129 min_benefit_score: 0.0,
130 ..Default::default()
131 }
132 }
133
134 pub fn conservative() -> Self {
136 Self {
137 strategy: FusionStrategy::Conservative,
138 max_fusion_size: 4,
139 enable_horizontal: false,
140 enable_loop_fusion: false,
141 min_benefit_score: 0.3,
142 ..Default::default()
143 }
144 }
145
146 pub fn memory_aware() -> Self {
148 Self {
149 strategy: FusionStrategy::MemoryAware,
150 memory_bandwidth_threshold: Some(100e9), ..Default::default()
152 }
153 }
154}
155
/// Linear cost model that compares running operations as separate kernels
/// with running them as one fused kernel.
///
/// The costs are relative weights: [`FusionCostModel::fusion_benefit`] is a
/// ratio, so scaling every weight by the same factor leaves it unchanged.
#[derive(Debug, Clone)]
pub struct FusionCostModel {
    /// Cost charged per unit of `data_size` for one pass over memory.
    pub memory_access_cost: f64,
    /// Per-operation overhead charged inside a fused kernel.
    pub compute_cost: f64,
    /// Fixed cost of launching one kernel.
    pub kernel_launch_cost: f64,
    /// Memory bandwidth of the target.
    // NOTE(review): not read by any cost function below; the default of
    // `100e9` is presumably bytes per second -- confirm.
    pub memory_bandwidth: f64,
}

impl Default for FusionCostModel {
    /// Memory access `1.0`, compute `0.1`, kernel launch `10.0`, bandwidth
    /// `100e9`.
    fn default() -> Self {
        Self {
            memory_access_cost: 1.0,
            compute_cost: 0.1,
            kernel_launch_cost: 10.0,
            memory_bandwidth: 100e9,
        }
    }
}

impl FusionCostModel {
    /// Cost of running `num_ops` operations as separate kernels.
    ///
    /// Every operation pays one memory pass over `data_size` plus one kernel
    /// launch.
    pub fn cost_separate(&self, num_ops: usize, data_size: usize) -> f64 {
        let ops = num_ops as f64;
        let memory_cost = self.memory_access_cost * data_size as f64 * ops;
        let launch_cost = self.kernel_launch_cost * ops;
        memory_cost + launch_cost
    }

    /// Cost of running `num_ops` operations as a single fused kernel.
    ///
    /// The fused kernel pays two memory passes over `data_size` regardless of
    /// `num_ops` (presumably one read of the input and one write of the
    /// output), a single launch, and a small per-operation compute overhead.
    pub fn cost_fused(&self, num_ops: usize, data_size: usize) -> f64 {
        let memory_cost = self.memory_access_cost * data_size as f64 * 2.0;
        let launch_cost = self.kernel_launch_cost;
        let compute_overhead = self.compute_cost * num_ops as f64;
        memory_cost + launch_cost + compute_overhead
    }

    /// Fraction of the separate-execution cost that fusing would save:
    /// `(separate - fused) / separate`.
    ///
    /// The result is negative when fusing is more expensive than running the
    /// operations separately. When the ratio is undefined or infinite (for
    /// example `num_ops == 0`, or a model whose weights are all zero, so the
    /// separate cost is zero) the function returns `0.0` instead of NaN or an
    /// infinity, so callers can compare and sort the result safely.
    pub fn fusion_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
        let separate_cost = self.cost_separate(num_ops, data_size);
        let fused_cost = self.cost_fused(num_ops, data_size);
        let benefit = (separate_cost - fused_cost) / separate_cost;

        if benefit.is_finite() {
            benefit
        } else {
            0.0
        }
    }
}
204
205pub struct FusionOptimizer {
207 config: FusionConfig,
208 cost_model: FusionCostModel,
209 candidates: Vec<FusionCandidate>,
210}
211
212impl FusionOptimizer {
213 pub fn new(config: FusionConfig) -> Self {
215 Self {
216 config,
217 cost_model: FusionCostModel::default(),
218 candidates: Vec::new(),
219 }
220 }
221
222 pub fn with_cost_model(config: FusionConfig, cost_model: FusionCostModel) -> Self {
224 Self {
225 config,
226 cost_model,
227 candidates: Vec::new(),
228 }
229 }
230
231 pub fn analyze(&mut self, graph: &EinsumGraph) -> Vec<FusionCandidate> {
233 self.candidates.clear();
234
235 if self.config.enable_patterns {
236 self.find_pattern_fusions(graph);
237 }
238
239 if self.config.enable_vertical {
240 self.find_vertical_fusions(graph);
241 }
242
243 if self.config.enable_horizontal {
244 self.find_horizontal_fusions(graph);
245 }
246
247 self.candidates
249 .sort_by(|a, b| b.benefit_score.partial_cmp(&a.benefit_score).unwrap());
250
251 self.candidates.clone()
252 }
253
254 fn find_pattern_fusions(&mut self, graph: &EinsumGraph) {
256 for node_id in 0..graph.nodes.len() {
258 let node_id = NodeId(node_id);
259 let node = &graph.nodes[node_id.0];
260
261 if matches!(node.op, OpType::Einsum { .. }) {
263 let consumers = self.find_consumers(graph, node_id);
265 for consumer in consumers {
266 let consumer_node = &graph.nodes[consumer.0];
267 if matches!(
268 consumer_node.op,
269 OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
270 ) {
271 let benefit = self.estimate_pattern_benefit(2, 1024); if benefit >= self.config.min_benefit_score {
274 self.candidates.push(FusionCandidate {
275 nodes: vec![node_id, consumer],
276 pattern: FusionPattern::MatMulActivation,
277 benefit_score: benefit,
278 memory_savings: 1024 * 4, compute_savings: 0.0,
280 });
281 }
282 }
283 }
284 }
285 }
286 }
287
288 fn find_vertical_fusions(&mut self, graph: &EinsumGraph) {
290 for node_id in 0..graph.nodes.len() {
291 let node_id = NodeId(node_id);
292 let consumers = self.find_consumers(graph, node_id);
293
294 if consumers.len() == 1 {
296 let consumer = consumers[0];
297 if self.can_fuse_vertically(graph, node_id, consumer) {
298 let benefit = self.cost_model.fusion_benefit(2, 1024);
299
300 if benefit >= self.config.min_benefit_score {
301 self.candidates.push(FusionCandidate {
302 nodes: vec![node_id, consumer],
303 pattern: FusionPattern::ElementwiseChain,
304 benefit_score: benefit,
305 memory_savings: 1024 * 4,
306 compute_savings: 0.0,
307 });
308 }
309 }
310 }
311 }
312 }
313
314 fn find_horizontal_fusions(&mut self, graph: &EinsumGraph) {
316 let _independent_groups: Vec<Vec<NodeId>> = Vec::new();
317
318 let mut depth_groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
320
321 for node_id in 0..graph.nodes.len() {
322 let depth = self.compute_depth(graph, NodeId(node_id));
323 depth_groups.entry(depth).or_default().push(NodeId(node_id));
324 }
325
326 for (_, nodes) in depth_groups {
328 if nodes.len() >= 2 {
329 for i in 0..nodes.len() {
331 for j in i + 1..nodes.len() {
332 if self.are_independent(graph, nodes[i], nodes[j])
333 && self.have_similar_ops(graph, nodes[i], nodes[j])
334 {
335 let benefit = self.cost_model.fusion_benefit(2, 512);
336
337 if benefit >= self.config.min_benefit_score {
338 self.candidates.push(FusionCandidate {
339 nodes: vec![nodes[i], nodes[j]],
340 pattern: FusionPattern::ParallelReductions,
341 benefit_score: benefit * 0.8, memory_savings: 512 * 4,
343 compute_savings: 0.0,
344 });
345 }
346 }
347 }
348 }
349 }
350 }
351 }
352
353 fn can_fuse_vertically(
355 &self,
356 _graph: &EinsumGraph,
357 _producer: NodeId,
358 _consumer: NodeId,
359 ) -> bool {
360 true }
367
368 fn are_independent(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
370 let a_deps = self.get_all_dependencies(graph, a);
372 let b_deps = self.get_all_dependencies(graph, b);
373
374 !a_deps.contains(&b) && !b_deps.contains(&a)
375 }
376
377 fn have_similar_ops(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
379 let op_a = &graph.nodes[a.0].op;
380 let op_b = &graph.nodes[b.0].op;
381
382 std::mem::discriminant(op_a) == std::mem::discriminant(op_b)
383 }
384
385 fn find_consumers(&self, graph: &EinsumGraph, producer: NodeId) -> Vec<NodeId> {
387 let mut consumers = Vec::new();
388
389 for (i, node) in graph.nodes.iter().enumerate() {
390 if node.inputs.iter().any(|&n| NodeId(n) == producer) {
391 consumers.push(NodeId(i));
392 }
393 }
394
395 consumers
396 }
397
398 fn get_all_dependencies(&self, graph: &EinsumGraph, node_id: NodeId) -> HashSet<NodeId> {
400 let mut deps = HashSet::new();
401 let mut to_visit = vec![node_id];
402
403 while let Some(current) = to_visit.pop() {
404 if deps.contains(¤t) {
405 continue;
406 }
407 deps.insert(current);
408
409 let node = &graph.nodes[current.0];
410 for &input in &node.inputs {
411 to_visit.push(NodeId(input));
412 }
413 }
414
415 deps
416 }
417
418 #[allow(clippy::only_used_in_recursion)]
420 fn compute_depth(&self, graph: &EinsumGraph, node_id: NodeId) -> usize {
421 let node = &graph.nodes[node_id.0];
422
423 if node.inputs.is_empty() {
424 0
425 } else {
426 1 + node
427 .inputs
428 .iter()
429 .map(|&input| self.compute_depth(graph, NodeId(input)))
430 .max()
431 .unwrap_or(0)
432 }
433 }
434
435 fn estimate_pattern_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
437 match self.config.strategy {
438 FusionStrategy::Aggressive => self.cost_model.fusion_benefit(num_ops, data_size) * 1.2,
439 FusionStrategy::Conservative => {
440 self.cost_model.fusion_benefit(num_ops, data_size) * 0.8
441 }
442 FusionStrategy::Balanced => self.cost_model.fusion_benefit(num_ops, data_size),
443 FusionStrategy::MemoryAware => {
444 let base_benefit = self.cost_model.fusion_benefit(num_ops, data_size);
445 base_benefit * 1.5
447 }
448 }
449 }
450
451 pub fn apply_fusions(
453 &self,
454 graph: &EinsumGraph,
455 _candidates: &[FusionCandidate],
456 ) -> Result<EinsumGraph, FusionError> {
457 Ok(graph.clone())
460 }
461
462 pub fn stats(&self) -> FusionStats {
464 let total_candidates = self.candidates.len();
465 let total_memory_savings: usize = self.candidates.iter().map(|c| c.memory_savings).sum();
466 let avg_benefit_score = if total_candidates > 0 {
467 self.candidates.iter().map(|c| c.benefit_score).sum::<f64>() / total_candidates as f64
468 } else {
469 0.0
470 };
471
472 let mut pattern_counts = HashMap::new();
473 for candidate in &self.candidates {
474 *pattern_counts.entry(candidate.pattern).or_insert(0) += 1;
475 }
476
477 FusionStats {
478 total_candidates,
479 total_memory_savings,
480 avg_benefit_score,
481 pattern_distribution: pattern_counts,
482 }
483 }
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct FusionStats {
489 pub total_candidates: usize,
491 pub total_memory_savings: usize,
493 pub avg_benefit_score: f64,
495 pub pattern_distribution: HashMap<FusionPattern, usize>,
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Two-node chain: an einsum (node 0) feeding a unary `relu` (node 1).
    fn create_test_graph() -> EinsumGraph {
        let matmul = EinsumNode {
            op: OpType::Einsum {
                spec: String::from("ij,jk->ik"),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        };
        let relu = EinsumNode {
            op: OpType::ElemUnary {
                op: String::from("relu"),
            },
            inputs: vec![0],
            outputs: vec![1],
            metadata: Default::default(),
        };

        let mut graph = EinsumGraph::new();
        graph.nodes.push(matmul);
        graph.nodes.push(relu);
        graph
    }

    /// Optimizer with default settings.
    fn default_optimizer() -> FusionOptimizer {
        FusionOptimizer::new(FusionConfig::default())
    }

    /// Optimizer that keeps every candidate regardless of its score.
    fn permissive_optimizer() -> FusionOptimizer {
        FusionOptimizer::new(FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        })
    }

    #[test]
    fn test_fusion_config() {
        let aggressive = FusionConfig::aggressive();
        assert_eq!(aggressive.strategy, FusionStrategy::Aggressive);
        assert!(aggressive.max_fusion_size >= FusionConfig::default().max_fusion_size);

        let conservative = FusionConfig::conservative();
        assert_eq!(conservative.strategy, FusionStrategy::Conservative);
    }

    #[test]
    fn test_cost_model() {
        let model = FusionCostModel::default();

        let three_ops = model.fusion_benefit(3, 1024);
        assert!(three_ops > 0.0);
        assert!(three_ops < 1.0);

        // Fusing more operations must pay off more.
        let five_ops = model.fusion_benefit(5, 1024);
        assert!(five_ops > three_ops);
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        let optimizer = default_optimizer();
        assert_eq!(optimizer.candidates.len(), 0);
    }

    #[test]
    fn test_fusion_analysis() {
        let graph = create_test_graph();
        let mut optimizer = permissive_optimizer();

        let found = optimizer.analyze(&graph);
        assert!(!found.is_empty());
    }

    #[test]
    fn test_consumer_finding() {
        let graph = create_test_graph();
        let optimizer = default_optimizer();

        let consumers = optimizer.find_consumers(&graph, NodeId(0));
        assert_eq!(consumers.len(), 1);
        assert_eq!(consumers[0], NodeId(1));
    }

    #[test]
    fn test_depth_computation() {
        let graph = create_test_graph();
        let optimizer = default_optimizer();

        assert_eq!(optimizer.compute_depth(&graph, NodeId(0)), 0);
        assert_eq!(optimizer.compute_depth(&graph, NodeId(1)), 1);
    }

    #[test]
    fn test_independence_check() {
        // Add a third node that is not connected to the chain.
        let detached = EinsumNode {
            op: OpType::ElemUnary {
                op: String::from("tanh"),
            },
            inputs: vec![],
            outputs: vec![2],
            metadata: Default::default(),
        };
        let mut graph = create_test_graph();
        graph.nodes.push(detached);

        let optimizer = default_optimizer();

        // Node 1 consumes node 0, so the two are dependent.
        assert!(!optimizer.are_independent(&graph, NodeId(0), NodeId(1)));

        // Node 2 shares nothing with node 0.
        assert!(optimizer.are_independent(&graph, NodeId(0), NodeId(2)));
    }

    #[test]
    fn test_fusion_stats() {
        let graph = create_test_graph();
        let mut optimizer = permissive_optimizer();

        optimizer.analyze(&graph);
        let stats = optimizer.stats();

        assert!(stats.total_candidates > 0);
        assert!(stats.avg_benefit_score >= 0.0);
    }

    #[test]
    fn test_similar_ops_check() {
        let graph = create_test_graph();
        let optimizer = default_optimizer();

        // A node always matches its own operation kind.
        assert!(optimizer.have_similar_ops(&graph, NodeId(0), NodeId(0)));

        // Einsum and element-wise unary are different kinds.
        assert!(!optimizer.have_similar_ops(&graph, NodeId(0), NodeId(1)));
    }
}