Skip to main content

tensorlogic_infer/
fusion.rs

//! Advanced kernel fusion for optimized execution.
//!
//! This module analyzes einsum graphs for kernel fusion opportunities:
//! - Pattern-based fusion (common operator patterns)
//! - Vertical fusion (producer-consumer chains)
//! - Horizontal fusion (independent parallel operations)
//! - Loop fusion for reductions
//! - Memory bandwidth-aware fusion decisions
//! - Cost modeling for fusion trade-offs
//!
//! Loop fusion and bandwidth-aware decisions can be configured through
//! `FusionConfig`, but the analyzer does not apply them yet.
10
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use tensorlogic_ir::{EinsumGraph, OpType};
14use thiserror::Error;
15
/// Zero-based position of a node inside `EinsumGraph::nodes`.
///
/// Wrapping the raw `usize` in a newtype keeps node positions from being
/// mixed up with other plain integers at call sites.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(pub usize);
19
20/// Fusion-related errors.
21#[derive(Error, Debug, Clone, PartialEq)]
22pub enum FusionError {
23    #[error("Fusion would create a cycle in the graph")]
24    WouldCreateCycle,
25
26    #[error("Incompatible operations for fusion: {0:?} and {1:?}")]
27    IncompatibleOps(OpType, OpType),
28
29    #[error("Fusion exceeds resource limits: {0}")]
30    ResourceLimitExceeded(String),
31
32    #[error("Invalid fusion pattern")]
33    InvalidPattern,
34}
35
36/// Fusion pattern types.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub enum FusionPattern {
39    /// Matrix multiplication followed by bias addition
40    MatMulBias,
41    /// Matrix multiplication followed by activation (ReLU, tanh, etc.)
42    MatMulActivation,
43    /// Bias addition followed by activation
44    BiasActivation,
45    /// BatchNorm + ReLU fusion
46    BatchNormReLU,
47    /// Conv + BatchNorm + ReLU
48    ConvBNReLU,
49    /// Elementwise operations chain
50    ElementwiseChain,
51    /// Reduction followed by elementwise
52    ReduceElementwise,
53    /// Multiple independent reductions (horizontal fusion)
54    ParallelReductions,
55    /// Broadcast followed by elementwise
56    BroadcastElementwise,
57}
58
59/// Fusion strategy.
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum FusionStrategy {
62    /// Conservative - only fuse proven beneficial patterns
63    Conservative,
64    /// Aggressive - fuse as much as possible
65    Aggressive,
66    /// Balanced - consider cost model
67    Balanced,
68    /// Memory-aware - prioritize memory bandwidth reduction
69    MemoryAware,
70}
71
72/// Fusion candidate representing potential fusion opportunity.
73#[derive(Debug, Clone, PartialEq)]
74pub struct FusionCandidate {
75    /// Nodes to be fused
76    pub nodes: Vec<NodeId>,
77    /// Pattern type
78    pub pattern: FusionPattern,
79    /// Estimated benefit score (higher is better)
80    pub benefit_score: f64,
81    /// Estimated memory savings (bytes)
82    pub memory_savings: usize,
83    /// Estimated compute reduction (FLOPS)
84    pub compute_savings: f64,
85}
86
87/// Fusion configuration.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FusionConfig {
90    /// Fusion strategy
91    pub strategy: FusionStrategy,
92    /// Maximum nodes per fused kernel
93    pub max_fusion_size: usize,
94    /// Enable pattern-based fusion
95    pub enable_patterns: bool,
96    /// Enable vertical fusion
97    pub enable_vertical: bool,
98    /// Enable horizontal fusion
99    pub enable_horizontal: bool,
100    /// Enable loop fusion
101    pub enable_loop_fusion: bool,
102    /// Memory bandwidth threshold (bytes/s)
103    pub memory_bandwidth_threshold: Option<f64>,
104    /// Minimum benefit score to apply fusion
105    pub min_benefit_score: f64,
106}
107
108impl Default for FusionConfig {
109    fn default() -> Self {
110        Self {
111            strategy: FusionStrategy::Balanced,
112            max_fusion_size: 8,
113            enable_patterns: true,
114            enable_vertical: true,
115            enable_horizontal: true,
116            enable_loop_fusion: true,
117            memory_bandwidth_threshold: None,
118            min_benefit_score: 0.1,
119        }
120    }
121}
122
123impl FusionConfig {
124    /// Create aggressive fusion configuration.
125    pub fn aggressive() -> Self {
126        Self {
127            strategy: FusionStrategy::Aggressive,
128            max_fusion_size: 16,
129            min_benefit_score: 0.0,
130            ..Default::default()
131        }
132    }
133
134    /// Create conservative fusion configuration.
135    pub fn conservative() -> Self {
136        Self {
137            strategy: FusionStrategy::Conservative,
138            max_fusion_size: 4,
139            enable_horizontal: false,
140            enable_loop_fusion: false,
141            min_benefit_score: 0.3,
142            ..Default::default()
143        }
144    }
145
146    /// Create memory-aware fusion configuration.
147    pub fn memory_aware() -> Self {
148        Self {
149            strategy: FusionStrategy::MemoryAware,
150            memory_bandwidth_threshold: Some(100e9), // 100 GB/s
151            ..Default::default()
152        }
153    }
154}
155
/// Cost model used to weigh fused execution against separate kernels.
///
/// The memory, compute and launch costs are relative units. For
/// [`FusionCostModel::fusion_benefit`] only their ratios matter.
#[derive(Debug, Clone)]
pub struct FusionCostModel {
    /// Cost charged per unit of data moved to or from memory.
    pub memory_access_cost: f64,
    /// Cost charged per operation; used as the fused-kernel overhead.
    pub compute_cost: f64,
    /// Fixed cost charged for every kernel launch.
    pub kernel_launch_cost: f64,
    /// Memory bandwidth in bytes per second. It is not used by the cost
    /// formulas in this impl.
    pub memory_bandwidth: f64,
}

impl Default for FusionCostModel {
    /// Relative costs of 1.0 (memory), 0.1 (compute) and 10.0 (launch),
    /// with 100 GB/s of memory bandwidth.
    fn default() -> Self {
        Self {
            memory_access_cost: 1.0,
            compute_cost: 0.1,
            kernel_launch_cost: 10.0,
            memory_bandwidth: 100e9, // 100 GB/s typical
        }
    }
}

impl FusionCostModel {
    /// Estimate the cost of running `num_ops` operations as separate kernels.
    ///
    /// Every kernel touches `data_size` units of memory and pays one launch:
    /// `memory_access_cost * data_size * num_ops + kernel_launch_cost * num_ops`.
    pub fn cost_separate(&self, num_ops: usize, data_size: usize) -> f64 {
        let ops = num_ops as f64;
        let memory_cost = self.memory_access_cost * data_size as f64 * ops;
        let launch_cost = self.kernel_launch_cost * ops;
        memory_cost + launch_cost
    }

    /// Estimate the cost of running `num_ops` operations as one fused kernel.
    ///
    /// The fused kernel reads its data once and writes it once, pays a single
    /// launch, and adds a small per-operation compute overhead.
    pub fn cost_fused(&self, num_ops: usize, data_size: usize) -> f64 {
        let memory_cost = self.memory_access_cost * data_size as f64 * 2.0;
        let launch_cost = self.kernel_launch_cost;
        let compute_overhead = self.compute_cost * num_ops as f64;
        memory_cost + launch_cost + compute_overhead
    }

    /// Relative saving from fusion: `(separate - fused) / separate`.
    ///
    /// The result is positive when fusing is cheaper and negative when it is
    /// more expensive (for example, fusing a single operation). Returns 0.0
    /// when the separate cost is not positive, which covers `num_ops == 0`
    /// and an all-zero cost model. In those cases the ratio is undefined.
    pub fn fusion_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
        let separate_cost = self.cost_separate(num_ops, data_size);
        // Dividing by a zero (or NaN) cost would yield NaN/±inf. Such values
        // silently fail every `>=` threshold check. Report "no benefit"
        // instead.
        if separate_cost.is_nan() || separate_cost <= 0.0 {
            return 0.0;
        }
        let fused_cost = self.cost_fused(num_ops, data_size);
        (separate_cost - fused_cost) / separate_cost
    }
}
204
205/// Kernel fusion analyzer and optimizer.
206pub struct FusionOptimizer {
207    config: FusionConfig,
208    cost_model: FusionCostModel,
209    candidates: Vec<FusionCandidate>,
210}
211
212impl FusionOptimizer {
213    /// Create a new fusion optimizer.
214    pub fn new(config: FusionConfig) -> Self {
215        Self {
216            config,
217            cost_model: FusionCostModel::default(),
218            candidates: Vec::new(),
219        }
220    }
221
222    /// Create with custom cost model.
223    pub fn with_cost_model(config: FusionConfig, cost_model: FusionCostModel) -> Self {
224        Self {
225            config,
226            cost_model,
227            candidates: Vec::new(),
228        }
229    }
230
231    /// Analyze graph and identify fusion opportunities.
232    pub fn analyze(&mut self, graph: &EinsumGraph) -> Vec<FusionCandidate> {
233        self.candidates.clear();
234
235        if self.config.enable_patterns {
236            self.find_pattern_fusions(graph);
237        }
238
239        if self.config.enable_vertical {
240            self.find_vertical_fusions(graph);
241        }
242
243        if self.config.enable_horizontal {
244            self.find_horizontal_fusions(graph);
245        }
246
247        // Sort by benefit score
248        self.candidates.sort_by(|a, b| {
249            b.benefit_score
250                .partial_cmp(&a.benefit_score)
251                .unwrap_or(std::cmp::Ordering::Equal)
252        });
253
254        self.candidates.clone()
255    }
256
257    /// Find pattern-based fusion opportunities.
258    fn find_pattern_fusions(&mut self, graph: &EinsumGraph) {
259        // Look for common patterns like MatMul + Bias + Activation
260        for node_id in 0..graph.nodes.len() {
261            let node_id = NodeId(node_id);
262            let node = &graph.nodes[node_id.0];
263
264            // Example: Look for MatMul followed by elementwise operations
265            if matches!(node.op, OpType::Einsum { .. }) {
266                // Check if output is consumed by elementwise op
267                let consumers = self.find_consumers(graph, node_id);
268                for consumer in consumers {
269                    let consumer_node = &graph.nodes[consumer.0];
270                    if matches!(
271                        consumer_node.op,
272                        OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
273                    ) {
274                        let benefit = self.estimate_pattern_benefit(2, 1024); // Example sizes
275
276                        if benefit >= self.config.min_benefit_score {
277                            self.candidates.push(FusionCandidate {
278                                nodes: vec![node_id, consumer],
279                                pattern: FusionPattern::MatMulActivation,
280                                benefit_score: benefit,
281                                memory_savings: 1024 * 4, // Rough estimate
282                                compute_savings: 0.0,
283                            });
284                        }
285                    }
286                }
287            }
288        }
289    }
290
291    /// Find vertical fusion opportunities (producer-consumer chains).
292    fn find_vertical_fusions(&mut self, graph: &EinsumGraph) {
293        for node_id in 0..graph.nodes.len() {
294            let node_id = NodeId(node_id);
295            let consumers = self.find_consumers(graph, node_id);
296
297            // If node has exactly one consumer, consider vertical fusion
298            if consumers.len() == 1 {
299                let consumer = consumers[0];
300                if self.can_fuse_vertically(graph, node_id, consumer) {
301                    let benefit = self.cost_model.fusion_benefit(2, 1024);
302
303                    if benefit >= self.config.min_benefit_score {
304                        self.candidates.push(FusionCandidate {
305                            nodes: vec![node_id, consumer],
306                            pattern: FusionPattern::ElementwiseChain,
307                            benefit_score: benefit,
308                            memory_savings: 1024 * 4,
309                            compute_savings: 0.0,
310                        });
311                    }
312                }
313            }
314        }
315    }
316
317    /// Find horizontal fusion opportunities (parallel independent ops).
318    fn find_horizontal_fusions(&mut self, graph: &EinsumGraph) {
319        let _independent_groups: Vec<Vec<NodeId>> = Vec::new();
320
321        // Group nodes by their depth in the graph
322        let mut depth_groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
323
324        for node_id in 0..graph.nodes.len() {
325            let depth = self.compute_depth(graph, NodeId(node_id));
326            depth_groups.entry(depth).or_default().push(NodeId(node_id));
327        }
328
329        // Within each depth level, find independent operations
330        for (_, nodes) in depth_groups {
331            if nodes.len() >= 2 {
332                // Check for independence and similar operation types
333                for i in 0..nodes.len() {
334                    for j in i + 1..nodes.len() {
335                        if self.are_independent(graph, nodes[i], nodes[j])
336                            && self.have_similar_ops(graph, nodes[i], nodes[j])
337                        {
338                            let benefit = self.cost_model.fusion_benefit(2, 512);
339
340                            if benefit >= self.config.min_benefit_score {
341                                self.candidates.push(FusionCandidate {
342                                    nodes: vec![nodes[i], nodes[j]],
343                                    pattern: FusionPattern::ParallelReductions,
344                                    benefit_score: benefit * 0.8, // Slightly lower benefit for horizontal
345                                    memory_savings: 512 * 4,
346                                    compute_savings: 0.0,
347                                });
348                            }
349                        }
350                    }
351                }
352            }
353        }
354    }
355
356    /// Check if vertical fusion is possible.
357    fn can_fuse_vertically(
358        &self,
359        _graph: &EinsumGraph,
360        _producer: NodeId,
361        _consumer: NodeId,
362    ) -> bool {
363        // Basic checks:
364        // 1. No other consumers of producer
365        // 2. Compatible operations
366        // 3. No intermediate materialization required
367        // 4. Within size limits
368        true // Simplified for now
369    }
370
371    /// Check if two nodes are independent.
372    fn are_independent(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
373        // Check if there's no data dependency between a and b
374        let a_deps = self.get_all_dependencies(graph, a);
375        let b_deps = self.get_all_dependencies(graph, b);
376
377        !a_deps.contains(&b) && !b_deps.contains(&a)
378    }
379
380    /// Check if two nodes have similar operations (for horizontal fusion).
381    fn have_similar_ops(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
382        let op_a = &graph.nodes[a.0].op;
383        let op_b = &graph.nodes[b.0].op;
384
385        std::mem::discriminant(op_a) == std::mem::discriminant(op_b)
386    }
387
388    /// Find all nodes that consume the output of a given node.
389    fn find_consumers(&self, graph: &EinsumGraph, producer: NodeId) -> Vec<NodeId> {
390        let mut consumers = Vec::new();
391
392        for (i, node) in graph.nodes.iter().enumerate() {
393            if node.inputs.iter().any(|&n| NodeId(n) == producer) {
394                consumers.push(NodeId(i));
395            }
396        }
397
398        consumers
399    }
400
401    /// Get all transitive dependencies of a node.
402    fn get_all_dependencies(&self, graph: &EinsumGraph, node_id: NodeId) -> HashSet<NodeId> {
403        let mut deps = HashSet::new();
404        let mut to_visit = vec![node_id];
405
406        while let Some(current) = to_visit.pop() {
407            if deps.contains(&current) {
408                continue;
409            }
410            deps.insert(current);
411
412            let node = &graph.nodes[current.0];
413            for &input in &node.inputs {
414                to_visit.push(NodeId(input));
415            }
416        }
417
418        deps
419    }
420
421    /// Compute the depth of a node in the graph.
422    #[allow(clippy::only_used_in_recursion)]
423    fn compute_depth(&self, graph: &EinsumGraph, node_id: NodeId) -> usize {
424        let node = &graph.nodes[node_id.0];
425
426        if node.inputs.is_empty() {
427            0
428        } else {
429            1 + node
430                .inputs
431                .iter()
432                .map(|&input| self.compute_depth(graph, NodeId(input)))
433                .max()
434                .unwrap_or(0)
435        }
436    }
437
438    /// Estimate benefit of fusing a pattern.
439    fn estimate_pattern_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
440        match self.config.strategy {
441            FusionStrategy::Aggressive => self.cost_model.fusion_benefit(num_ops, data_size) * 1.2,
442            FusionStrategy::Conservative => {
443                self.cost_model.fusion_benefit(num_ops, data_size) * 0.8
444            }
445            FusionStrategy::Balanced => self.cost_model.fusion_benefit(num_ops, data_size),
446            FusionStrategy::MemoryAware => {
447                let base_benefit = self.cost_model.fusion_benefit(num_ops, data_size);
448                // Prioritize memory savings
449                base_benefit * 1.5
450            }
451        }
452    }
453
454    /// Apply fusion candidates to create optimized graph.
455    pub fn apply_fusions(
456        &self,
457        graph: &EinsumGraph,
458        _candidates: &[FusionCandidate],
459    ) -> Result<EinsumGraph, FusionError> {
460        // This would create a new graph with fused operations
461        // For now, return a clone
462        Ok(graph.clone())
463    }
464
465    /// Get fusion statistics.
466    pub fn stats(&self) -> FusionStats {
467        let total_candidates = self.candidates.len();
468        let total_memory_savings: usize = self.candidates.iter().map(|c| c.memory_savings).sum();
469        let avg_benefit_score = if total_candidates > 0 {
470            self.candidates.iter().map(|c| c.benefit_score).sum::<f64>() / total_candidates as f64
471        } else {
472            0.0
473        };
474
475        let mut pattern_counts = HashMap::new();
476        for candidate in &self.candidates {
477            *pattern_counts.entry(candidate.pattern).or_insert(0) += 1;
478        }
479
480        FusionStats {
481            total_candidates,
482            total_memory_savings,
483            avg_benefit_score,
484            pattern_distribution: pattern_counts,
485        }
486    }
487}
488
489/// Fusion statistics.
490#[derive(Debug, Clone, Serialize, Deserialize)]
491pub struct FusionStats {
492    /// Total number of fusion candidates found
493    pub total_candidates: usize,
494    /// Total estimated memory savings (bytes)
495    pub total_memory_savings: usize,
496    /// Average benefit score
497    pub avg_benefit_score: f64,
498    /// Distribution of fusion patterns
499    pub pattern_distribution: HashMap<FusionPattern, usize>,
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Wrap `op` in a node that reads `inputs` and writes `outputs`.
    fn make_node(op: OpType, inputs: Vec<usize>, outputs: Vec<usize>) -> EinsumNode {
        EinsumNode {
            op,
            inputs,
            outputs,
            metadata: Default::default(),
        }
    }

    /// Two-node graph: an einsum (node 0) whose output feeds a ReLU (node 1).
    fn create_test_graph() -> EinsumGraph {
        let matmul = make_node(
            OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            vec![],
            vec![0],
        );
        let relu = make_node(
            OpType::ElemUnary {
                op: "relu".to_string(),
            },
            vec![0],
            vec![1],
        );

        let mut graph = EinsumGraph::new();
        graph.nodes.push(matmul);
        graph.nodes.push(relu);
        graph
    }

    /// Optimizer that keeps every candidate, whatever its score.
    fn permissive_optimizer() -> FusionOptimizer {
        FusionOptimizer::new(FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        })
    }

    #[test]
    fn test_fusion_config() {
        let aggressive = FusionConfig::aggressive();
        let baseline = FusionConfig::default();
        assert_eq!(aggressive.strategy, FusionStrategy::Aggressive);
        assert!(aggressive.max_fusion_size >= baseline.max_fusion_size);

        assert_eq!(
            FusionConfig::conservative().strategy,
            FusionStrategy::Conservative
        );
    }

    #[test]
    fn test_cost_model() {
        let model = FusionCostModel::default();

        let three_ops = model.fusion_benefit(3, 1024);
        assert!(three_ops > 0.0 && three_ops < 1.0);

        // Fusing more operations saves more launches, so the benefit grows.
        let five_ops = model.fusion_benefit(5, 1024);
        assert!(five_ops > three_ops);
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(optimizer.candidates.is_empty());
    }

    #[test]
    fn test_fusion_analysis() {
        let graph = create_test_graph();
        let mut optimizer = permissive_optimizer();

        // The einsum -> relu pair must be reported.
        let found = optimizer.analyze(&graph);
        assert!(!found.is_empty());
    }

    #[test]
    fn test_consumer_finding() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        assert_eq!(optimizer.find_consumers(&graph, NodeId(0)), vec![NodeId(1)]);
    }

    #[test]
    fn test_depth_computation() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        let depths: Vec<usize> = (0..2)
            .map(|i| optimizer.compute_depth(&graph, NodeId(i)))
            .collect();
        assert_eq!(depths, vec![0, 1]);
    }

    #[test]
    fn test_independence_check() {
        let mut graph = create_test_graph();
        // Node 2 reads nothing, so it depends on neither existing node.
        graph.nodes.push(make_node(
            OpType::ElemUnary {
                op: "tanh".to_string(),
            },
            vec![],
            vec![2],
        ));
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        assert!(
            !optimizer.are_independent(&graph, NodeId(0), NodeId(1)),
            "relu depends on the einsum"
        );
        assert!(optimizer.are_independent(&graph, NodeId(0), NodeId(2)));
    }

    #[test]
    fn test_fusion_stats() {
        let graph = create_test_graph();
        let mut optimizer = permissive_optimizer();

        optimizer.analyze(&graph);
        let stats = optimizer.stats();

        assert!(stats.total_candidates > 0);
        assert!(stats.avg_benefit_score >= 0.0);
    }

    #[test]
    fn test_similar_ops_check() {
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        // A node trivially matches itself...
        assert!(optimizer.have_similar_ops(&graph, NodeId(0), NodeId(0)));
        // ...but an einsum and an element-wise op differ.
        assert!(!optimizer.have_similar_ops(&graph, NodeId(0), NodeId(1)));
    }
}