Skip to main content

tensorlogic_infer/
fusion.rs

//! Advanced kernel fusion for optimized execution.
//!
//! This module provides sophisticated kernel fusion capabilities:
//! - Pattern-based fusion (common operator patterns)
//! - Vertical fusion (producer-consumer chains)
//! - Horizontal fusion (independent parallel operations)
//! - Loop fusion for reductions
//! - Memory bandwidth-aware fusion decisions
//! - Cost modeling for fusion trade-offs

11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use tensorlogic_ir::{EinsumGraph, OpType};
14use thiserror::Error;
15
/// Identifies a node by its zero-based position in `EinsumGraph::nodes`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(pub usize);
19
20/// Fusion-related errors.
21#[derive(Error, Debug, Clone, PartialEq)]
22pub enum FusionError {
23    #[error("Fusion would create a cycle in the graph")]
24    WouldCreateCycle,
25
26    #[error("Incompatible operations for fusion: {0:?} and {1:?}")]
27    IncompatibleOps(OpType, OpType),
28
29    #[error("Fusion exceeds resource limits: {0}")]
30    ResourceLimitExceeded(String),
31
32    #[error("Invalid fusion pattern")]
33    InvalidPattern,
34}
35
36/// Fusion pattern types.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub enum FusionPattern {
39    /// Matrix multiplication followed by bias addition
40    MatMulBias,
41    /// Matrix multiplication followed by activation (ReLU, tanh, etc.)
42    MatMulActivation,
43    /// Bias addition followed by activation
44    BiasActivation,
45    /// BatchNorm + ReLU fusion
46    BatchNormReLU,
47    /// Conv + BatchNorm + ReLU
48    ConvBNReLU,
49    /// Elementwise operations chain
50    ElementwiseChain,
51    /// Reduction followed by elementwise
52    ReduceElementwise,
53    /// Multiple independent reductions (horizontal fusion)
54    ParallelReductions,
55    /// Broadcast followed by elementwise
56    BroadcastElementwise,
57}
58
59/// Fusion strategy.
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum FusionStrategy {
62    /// Conservative - only fuse proven beneficial patterns
63    Conservative,
64    /// Aggressive - fuse as much as possible
65    Aggressive,
66    /// Balanced - consider cost model
67    Balanced,
68    /// Memory-aware - prioritize memory bandwidth reduction
69    MemoryAware,
70}
71
72/// Fusion candidate representing potential fusion opportunity.
73#[derive(Debug, Clone, PartialEq)]
74pub struct FusionCandidate {
75    /// Nodes to be fused
76    pub nodes: Vec<NodeId>,
77    /// Pattern type
78    pub pattern: FusionPattern,
79    /// Estimated benefit score (higher is better)
80    pub benefit_score: f64,
81    /// Estimated memory savings (bytes)
82    pub memory_savings: usize,
83    /// Estimated compute reduction (FLOPS)
84    pub compute_savings: f64,
85}
86
87/// Fusion configuration.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FusionConfig {
90    /// Fusion strategy
91    pub strategy: FusionStrategy,
92    /// Maximum nodes per fused kernel
93    pub max_fusion_size: usize,
94    /// Enable pattern-based fusion
95    pub enable_patterns: bool,
96    /// Enable vertical fusion
97    pub enable_vertical: bool,
98    /// Enable horizontal fusion
99    pub enable_horizontal: bool,
100    /// Enable loop fusion
101    pub enable_loop_fusion: bool,
102    /// Memory bandwidth threshold (bytes/s)
103    pub memory_bandwidth_threshold: Option<f64>,
104    /// Minimum benefit score to apply fusion
105    pub min_benefit_score: f64,
106}
107
108impl Default for FusionConfig {
109    fn default() -> Self {
110        Self {
111            strategy: FusionStrategy::Balanced,
112            max_fusion_size: 8,
113            enable_patterns: true,
114            enable_vertical: true,
115            enable_horizontal: true,
116            enable_loop_fusion: true,
117            memory_bandwidth_threshold: None,
118            min_benefit_score: 0.1,
119        }
120    }
121}
122
123impl FusionConfig {
124    /// Create aggressive fusion configuration.
125    pub fn aggressive() -> Self {
126        Self {
127            strategy: FusionStrategy::Aggressive,
128            max_fusion_size: 16,
129            min_benefit_score: 0.0,
130            ..Default::default()
131        }
132    }
133
134    /// Create conservative fusion configuration.
135    pub fn conservative() -> Self {
136        Self {
137            strategy: FusionStrategy::Conservative,
138            max_fusion_size: 4,
139            enable_horizontal: false,
140            enable_loop_fusion: false,
141            min_benefit_score: 0.3,
142            ..Default::default()
143        }
144    }
145
146    /// Create memory-aware fusion configuration.
147    pub fn memory_aware() -> Self {
148        Self {
149            strategy: FusionStrategy::MemoryAware,
150            memory_bandwidth_threshold: Some(100e9), // 100 GB/s
151            ..Default::default()
152        }
153    }
154}
155
/// Relative cost parameters used to decide whether fusing operations pays off.
#[derive(Clone, Debug)]
pub struct FusionCostModel {
    /// Cost of a unit of memory access, in relative units.
    pub memory_access_cost: f64,
    /// Cost of a unit of compute, in relative units.
    pub compute_cost: f64,
    /// Fixed overhead paid for each kernel launch.
    pub kernel_launch_cost: f64,
    /// Memory bandwidth in bytes per second. The cost methods do not currently
    /// read this field.
    pub memory_bandwidth: f64,
}
168
169impl Default for FusionCostModel {
170    fn default() -> Self {
171        Self {
172            memory_access_cost: 1.0,
173            compute_cost: 0.1,
174            kernel_launch_cost: 10.0,
175            memory_bandwidth: 100e9, // 100 GB/s typical
176        }
177    }
178}
179
180impl FusionCostModel {
181    /// Estimate cost of executing operations separately.
182    pub fn cost_separate(&self, num_ops: usize, data_size: usize) -> f64 {
183        let memory_cost = self.memory_access_cost * data_size as f64 * num_ops as f64;
184        let launch_cost = self.kernel_launch_cost * num_ops as f64;
185        memory_cost + launch_cost
186    }
187
188    /// Estimate cost of executing operations fused.
189    pub fn cost_fused(&self, num_ops: usize, data_size: usize) -> f64 {
190        // Fused operations read data once, write once
191        let memory_cost = self.memory_access_cost * data_size as f64 * 2.0;
192        let launch_cost = self.kernel_launch_cost;
193        let compute_overhead = self.compute_cost * num_ops as f64; // slight overhead
194        memory_cost + launch_cost + compute_overhead
195    }
196
197    /// Calculate fusion benefit.
198    pub fn fusion_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
199        let separate_cost = self.cost_separate(num_ops, data_size);
200        let fused_cost = self.cost_fused(num_ops, data_size);
201        (separate_cost - fused_cost) / separate_cost
202    }
203}
204
205/// Kernel fusion analyzer and optimizer.
206pub struct FusionOptimizer {
207    config: FusionConfig,
208    cost_model: FusionCostModel,
209    candidates: Vec<FusionCandidate>,
210}
211
212impl FusionOptimizer {
213    /// Create a new fusion optimizer.
214    pub fn new(config: FusionConfig) -> Self {
215        Self {
216            config,
217            cost_model: FusionCostModel::default(),
218            candidates: Vec::new(),
219        }
220    }
221
222    /// Create with custom cost model.
223    pub fn with_cost_model(config: FusionConfig, cost_model: FusionCostModel) -> Self {
224        Self {
225            config,
226            cost_model,
227            candidates: Vec::new(),
228        }
229    }
230
231    /// Analyze graph and identify fusion opportunities.
232    pub fn analyze(&mut self, graph: &EinsumGraph) -> Vec<FusionCandidate> {
233        self.candidates.clear();
234
235        if self.config.enable_patterns {
236            self.find_pattern_fusions(graph);
237        }
238
239        if self.config.enable_vertical {
240            self.find_vertical_fusions(graph);
241        }
242
243        if self.config.enable_horizontal {
244            self.find_horizontal_fusions(graph);
245        }
246
247        // Sort by benefit score
248        self.candidates
249            .sort_by(|a, b| b.benefit_score.partial_cmp(&a.benefit_score).unwrap());
250
251        self.candidates.clone()
252    }
253
254    /// Find pattern-based fusion opportunities.
255    fn find_pattern_fusions(&mut self, graph: &EinsumGraph) {
256        // Look for common patterns like MatMul + Bias + Activation
257        for node_id in 0..graph.nodes.len() {
258            let node_id = NodeId(node_id);
259            let node = &graph.nodes[node_id.0];
260
261            // Example: Look for MatMul followed by elementwise operations
262            if matches!(node.op, OpType::Einsum { .. }) {
263                // Check if output is consumed by elementwise op
264                let consumers = self.find_consumers(graph, node_id);
265                for consumer in consumers {
266                    let consumer_node = &graph.nodes[consumer.0];
267                    if matches!(
268                        consumer_node.op,
269                        OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
270                    ) {
271                        let benefit = self.estimate_pattern_benefit(2, 1024); // Example sizes
272
273                        if benefit >= self.config.min_benefit_score {
274                            self.candidates.push(FusionCandidate {
275                                nodes: vec![node_id, consumer],
276                                pattern: FusionPattern::MatMulActivation,
277                                benefit_score: benefit,
278                                memory_savings: 1024 * 4, // Rough estimate
279                                compute_savings: 0.0,
280                            });
281                        }
282                    }
283                }
284            }
285        }
286    }
287
288    /// Find vertical fusion opportunities (producer-consumer chains).
289    fn find_vertical_fusions(&mut self, graph: &EinsumGraph) {
290        for node_id in 0..graph.nodes.len() {
291            let node_id = NodeId(node_id);
292            let consumers = self.find_consumers(graph, node_id);
293
294            // If node has exactly one consumer, consider vertical fusion
295            if consumers.len() == 1 {
296                let consumer = consumers[0];
297                if self.can_fuse_vertically(graph, node_id, consumer) {
298                    let benefit = self.cost_model.fusion_benefit(2, 1024);
299
300                    if benefit >= self.config.min_benefit_score {
301                        self.candidates.push(FusionCandidate {
302                            nodes: vec![node_id, consumer],
303                            pattern: FusionPattern::ElementwiseChain,
304                            benefit_score: benefit,
305                            memory_savings: 1024 * 4,
306                            compute_savings: 0.0,
307                        });
308                    }
309                }
310            }
311        }
312    }
313
314    /// Find horizontal fusion opportunities (parallel independent ops).
315    fn find_horizontal_fusions(&mut self, graph: &EinsumGraph) {
316        let _independent_groups: Vec<Vec<NodeId>> = Vec::new();
317
318        // Group nodes by their depth in the graph
319        let mut depth_groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
320
321        for node_id in 0..graph.nodes.len() {
322            let depth = self.compute_depth(graph, NodeId(node_id));
323            depth_groups.entry(depth).or_default().push(NodeId(node_id));
324        }
325
326        // Within each depth level, find independent operations
327        for (_, nodes) in depth_groups {
328            if nodes.len() >= 2 {
329                // Check for independence and similar operation types
330                for i in 0..nodes.len() {
331                    for j in i + 1..nodes.len() {
332                        if self.are_independent(graph, nodes[i], nodes[j])
333                            && self.have_similar_ops(graph, nodes[i], nodes[j])
334                        {
335                            let benefit = self.cost_model.fusion_benefit(2, 512);
336
337                            if benefit >= self.config.min_benefit_score {
338                                self.candidates.push(FusionCandidate {
339                                    nodes: vec![nodes[i], nodes[j]],
340                                    pattern: FusionPattern::ParallelReductions,
341                                    benefit_score: benefit * 0.8, // Slightly lower benefit for horizontal
342                                    memory_savings: 512 * 4,
343                                    compute_savings: 0.0,
344                                });
345                            }
346                        }
347                    }
348                }
349            }
350        }
351    }
352
353    /// Check if vertical fusion is possible.
354    fn can_fuse_vertically(
355        &self,
356        _graph: &EinsumGraph,
357        _producer: NodeId,
358        _consumer: NodeId,
359    ) -> bool {
360        // Basic checks:
361        // 1. No other consumers of producer
362        // 2. Compatible operations
363        // 3. No intermediate materialization required
364        // 4. Within size limits
365        true // Simplified for now
366    }
367
368    /// Check if two nodes are independent.
369    fn are_independent(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
370        // Check if there's no data dependency between a and b
371        let a_deps = self.get_all_dependencies(graph, a);
372        let b_deps = self.get_all_dependencies(graph, b);
373
374        !a_deps.contains(&b) && !b_deps.contains(&a)
375    }
376
377    /// Check if two nodes have similar operations (for horizontal fusion).
378    fn have_similar_ops(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
379        let op_a = &graph.nodes[a.0].op;
380        let op_b = &graph.nodes[b.0].op;
381
382        std::mem::discriminant(op_a) == std::mem::discriminant(op_b)
383    }
384
385    /// Find all nodes that consume the output of a given node.
386    fn find_consumers(&self, graph: &EinsumGraph, producer: NodeId) -> Vec<NodeId> {
387        let mut consumers = Vec::new();
388
389        for (i, node) in graph.nodes.iter().enumerate() {
390            if node.inputs.iter().any(|&n| NodeId(n) == producer) {
391                consumers.push(NodeId(i));
392            }
393        }
394
395        consumers
396    }
397
398    /// Get all transitive dependencies of a node.
399    fn get_all_dependencies(&self, graph: &EinsumGraph, node_id: NodeId) -> HashSet<NodeId> {
400        let mut deps = HashSet::new();
401        let mut to_visit = vec![node_id];
402
403        while let Some(current) = to_visit.pop() {
404            if deps.contains(&current) {
405                continue;
406            }
407            deps.insert(current);
408
409            let node = &graph.nodes[current.0];
410            for &input in &node.inputs {
411                to_visit.push(NodeId(input));
412            }
413        }
414
415        deps
416    }
417
418    /// Compute the depth of a node in the graph.
419    #[allow(clippy::only_used_in_recursion)]
420    fn compute_depth(&self, graph: &EinsumGraph, node_id: NodeId) -> usize {
421        let node = &graph.nodes[node_id.0];
422
423        if node.inputs.is_empty() {
424            0
425        } else {
426            1 + node
427                .inputs
428                .iter()
429                .map(|&input| self.compute_depth(graph, NodeId(input)))
430                .max()
431                .unwrap_or(0)
432        }
433    }
434
435    /// Estimate benefit of fusing a pattern.
436    fn estimate_pattern_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
437        match self.config.strategy {
438            FusionStrategy::Aggressive => self.cost_model.fusion_benefit(num_ops, data_size) * 1.2,
439            FusionStrategy::Conservative => {
440                self.cost_model.fusion_benefit(num_ops, data_size) * 0.8
441            }
442            FusionStrategy::Balanced => self.cost_model.fusion_benefit(num_ops, data_size),
443            FusionStrategy::MemoryAware => {
444                let base_benefit = self.cost_model.fusion_benefit(num_ops, data_size);
445                // Prioritize memory savings
446                base_benefit * 1.5
447            }
448        }
449    }
450
451    /// Apply fusion candidates to create optimized graph.
452    pub fn apply_fusions(
453        &self,
454        graph: &EinsumGraph,
455        _candidates: &[FusionCandidate],
456    ) -> Result<EinsumGraph, FusionError> {
457        // This would create a new graph with fused operations
458        // For now, return a clone
459        Ok(graph.clone())
460    }
461
462    /// Get fusion statistics.
463    pub fn stats(&self) -> FusionStats {
464        let total_candidates = self.candidates.len();
465        let total_memory_savings: usize = self.candidates.iter().map(|c| c.memory_savings).sum();
466        let avg_benefit_score = if total_candidates > 0 {
467            self.candidates.iter().map(|c| c.benefit_score).sum::<f64>() / total_candidates as f64
468        } else {
469            0.0
470        };
471
472        let mut pattern_counts = HashMap::new();
473        for candidate in &self.candidates {
474            *pattern_counts.entry(candidate.pattern).or_insert(0) += 1;
475        }
476
477        FusionStats {
478            total_candidates,
479            total_memory_savings,
480            avg_benefit_score,
481            pattern_distribution: pattern_counts,
482        }
483    }
484}
485
486/// Fusion statistics.
487#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct FusionStats {
489    /// Total number of fusion candidates found
490    pub total_candidates: usize,
491    /// Total estimated memory savings (bytes)
492    pub total_memory_savings: usize,
493    /// Average benefit score
494    pub avg_benefit_score: f64,
495    /// Distribution of fusion patterns
496    pub pattern_distribution: HashMap<FusionPattern, usize>,
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Two-node graph. Node 0 is an einsum matmul with outputs `[0]`. Node 1 is
    /// a ReLU with inputs `[0]` and outputs `[1]`.
    fn create_test_graph() -> EinsumGraph {
        let matmul = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        };
        let relu = EinsumNode {
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            inputs: vec![0],
            outputs: vec![1],
            metadata: Default::default(),
        };

        let mut graph = EinsumGraph::new();
        graph.nodes.extend([matmul, relu]);
        graph
    }

    /// Optimizer with the default configuration but no minimum benefit score,
    /// so every detected candidate is kept.
    fn permissive_optimizer() -> FusionOptimizer {
        FusionOptimizer::new(FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        })
    }

    /// Optimizer with the stock default configuration.
    fn default_optimizer() -> FusionOptimizer {
        FusionOptimizer::new(FusionConfig::default())
    }

    #[test]
    fn test_fusion_config() {
        let aggressive = FusionConfig::aggressive();
        assert_eq!(aggressive.strategy, FusionStrategy::Aggressive);
        assert!(aggressive.max_fusion_size >= FusionConfig::default().max_fusion_size);

        assert_eq!(
            FusionConfig::conservative().strategy,
            FusionStrategy::Conservative
        );
    }

    #[test]
    fn test_cost_model() {
        let model = FusionCostModel::default();

        let three_ops = model.fusion_benefit(3, 1024);
        assert!(three_ops > 0.0 && three_ops < 1.0);

        // Fusing more operations amortises more launches, so the benefit grows.
        assert!(model.fusion_benefit(5, 1024) > three_ops);
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        assert!(default_optimizer().candidates.is_empty());
    }

    #[test]
    fn test_fusion_analysis() {
        let graph = create_test_graph();
        // The matmul -> relu pair should be reported.
        let found = permissive_optimizer().analyze(&graph);
        assert!(!found.is_empty());
    }

    #[test]
    fn test_consumer_finding() {
        let graph = create_test_graph();
        let consumers = default_optimizer().find_consumers(&graph, NodeId(0));
        assert_eq!(consumers, vec![NodeId(1)]);
    }

    #[test]
    fn test_depth_computation() {
        let graph = create_test_graph();
        let optimizer = default_optimizer();
        let depths: Vec<usize> = (0..2)
            .map(|i| optimizer.compute_depth(&graph, NodeId(i)))
            .collect();
        assert_eq!(depths, vec![0, 1]);
    }

    #[test]
    fn test_independence_check() {
        let mut graph = create_test_graph();
        // A tanh with no inputs shares no data with the existing nodes.
        graph.nodes.push(EinsumNode {
            op: OpType::ElemUnary {
                op: "tanh".to_string(),
            },
            inputs: vec![],
            outputs: vec![2],
            metadata: Default::default(),
        });

        let optimizer = default_optimizer();
        assert!(
            !optimizer.are_independent(&graph, NodeId(0), NodeId(1)),
            "relu reads the matmul output"
        );
        assert!(optimizer.are_independent(&graph, NodeId(0), NodeId(2)));
    }

    #[test]
    fn test_fusion_stats() {
        let graph = create_test_graph();
        let mut optimizer = permissive_optimizer();
        optimizer.analyze(&graph);

        let stats = optimizer.stats();
        assert!(stats.total_candidates > 0);
        assert!(stats.avg_benefit_score >= 0.0);
    }

    #[test]
    fn test_similar_ops_check() {
        let graph = create_test_graph();
        let optimizer = default_optimizer();

        // A node trivially matches itself; einsum vs. elementwise do not match.
        assert!(optimizer.have_similar_ops(&graph, NodeId(0), NodeId(0)));
        assert!(!optimizer.have_similar_ops(&graph, NodeId(0), NodeId(1)));
    }
}