// trustformers_core/kernel_fusion/engine.rs

1//! Kernel fusion engine implementation
2//!
3//! This module contains the main KernelFusionEngine implementation with
4//! pattern matching, constraint verification, and kernel generation logic.
5
6#![allow(unused_variables)] // Kernel fusion engine
7
8use crate::errors::{Result, TrustformersError};
9use crate::kernel_fusion::graph::{ComputationGraph, Device, GraphNode, TensorInfo};
10use crate::kernel_fusion::kernel::{FusedKernel, KernelImplementation};
11use crate::kernel_fusion::operation_types::{FusionConstraint, FusionPattern, OperationType};
12use crate::kernel_fusion::performance::{
13    DeviceCharacteristics, FusionStatistics, OperationCost, PerformanceDatabase,
14};
15use anyhow::anyhow;
16use std::collections::{HashMap, HashSet};
17use std::sync::{Arc, RwLock};
18
/// Kernel fusion engine
///
/// Holds the registered fusion patterns/constraints and shared, lock-protected
/// state for generated kernels, the performance model, and fusion statistics.
pub struct KernelFusionEngine {
    /// Fusion patterns searched for during graph analysis.
    pub patterns: Vec<FusionPattern>,
    /// Constraints every candidate fusion must satisfy.
    pub constraints: Vec<FusionConstraint>,
    /// Generated kernels, keyed by kernel id.
    pub generated_kernels: Arc<RwLock<HashMap<String, FusedKernel>>>,
    /// Per-operation cost model and device characteristics.
    pub performance_database: Arc<RwLock<PerformanceDatabase>>,
    /// Running statistics about successful fusions.
    pub fusion_statistics: Arc<RwLock<FusionStatistics>>,
}
27
/// A candidate fusion discovered by [`KernelFusionEngine::analyze_graph`].
pub struct FusionOpportunity {
    /// The pattern this candidate matched.
    pub pattern: FusionPattern,
    /// Ids of the graph nodes covered by the fusion, in chain order.
    pub node_ids: Vec<String>,
    /// Estimated speedup factor (individual cost / fused cost).
    pub estimated_benefit: f64,
    /// Whether all engine constraints held at analysis time.
    pub constraints_satisfied: bool,
}
34
35impl KernelFusionEngine {
36    pub fn new() -> Self {
37        let mut engine = Self {
38            patterns: Vec::new(),
39            constraints: Vec::new(),
40            generated_kernels: Arc::new(RwLock::new(HashMap::new())),
41            performance_database: Arc::new(RwLock::new(PerformanceDatabase::default())),
42            fusion_statistics: Arc::new(RwLock::new(FusionStatistics::default())),
43        };
44
45        engine.initialize_default_patterns();
46        engine.initialize_performance_database();
47        engine
48    }
49
50    pub fn analyze_graph(&self, graph: &ComputationGraph) -> Result<Vec<FusionOpportunity>> {
51        let mut opportunities = Vec::new();
52
53        for pattern in &self.patterns {
54            let mut pattern_opportunities = self.find_pattern_matches(graph, pattern)?;
55            opportunities.append(&mut pattern_opportunities);
56        }
57
58        // Sort by estimated benefit (descending)
59        opportunities.sort_by(|a, b| {
60            b.estimated_benefit
61                .partial_cmp(&a.estimated_benefit)
62                .unwrap_or(std::cmp::Ordering::Equal)
63        });
64
65        Ok(opportunities)
66    }
67
68    pub fn fuse_operations(
69        &self,
70        graph: &ComputationGraph,
71        opportunity: &FusionOpportunity,
72    ) -> Result<FusedKernel> {
73        // Verify constraints one more time
74        if !self.verify_fusion_constraints(&opportunity.node_ids, graph)? {
75            return Err(TrustformersError::invalid_operation(
76                "Fusion constraints not satisfied".to_string(),
77            ));
78        }
79
80        // Generate fused kernel
81        let kernel_name = self.generate_kernel_name(&opportunity.pattern);
82        let implementation = self.generate_kernel_implementation(opportunity)?;
83
84        let fused_kernel = FusedKernel::new(
85            format!("fused_{}", uuid::Uuid::new_v4()),
86            kernel_name,
87            opportunity.pattern.clone(),
88            opportunity.node_ids.clone(),
89        )
90        .with_implementation(implementation)
91        .with_speedup(opportunity.estimated_benefit);
92
93        // Store generated kernel
94        self.generated_kernels
95            .write()
96            .expect("generated_kernels lock should not be poisoned")
97            .insert(fused_kernel.id.clone(), fused_kernel.clone());
98
99        // Calculate memory savings from eliminating intermediate tensors
100        let memory_saved = self.calculate_memory_savings(graph, &opportunity.node_ids)?;
101
102        // Update statistics
103        let mut stats = self
104            .fusion_statistics
105            .write()
106            .expect("fusion_statistics lock should not be poisoned");
107        stats.record_successful_fusion(
108            &self.pattern_name(&opportunity.pattern),
109            opportunity.estimated_benefit,
110            memory_saved,
111        );
112
113        Ok(fused_kernel)
114    }
115
116    fn initialize_default_patterns(&mut self) {
117        // Element-wise operation chains
118        self.patterns.push(FusionPattern::ElementWiseChain(vec![
119            OperationType::Add,
120            OperationType::ReLU,
121        ]));
122
123        self.patterns.push(FusionPattern::ElementWiseChain(vec![
124            OperationType::Multiply,
125            OperationType::Add,
126            OperationType::GELU,
127        ]));
128
129        // Linear + activation patterns
130        self.patterns.push(FusionPattern::LinearActivation {
131            matmul: OperationType::MatMul,
132            bias_add: true,
133            activation: Some(OperationType::ReLU),
134        });
135
136        self.patterns.push(FusionPattern::LinearActivation {
137            matmul: OperationType::MatMul,
138            bias_add: true,
139            activation: Some(OperationType::GELU),
140        });
141
142        // Layer normalization patterns
143        self.patterns.push(FusionPattern::BatchNorm {
144            normalize: true,
145            scale: true,
146            shift: true,
147            activation: None,
148        });
149
150        // Attention fusion
151        self.patterns.push(FusionPattern::AttentionFusion {
152            query_key_matmul: true,
153            softmax: true,
154            value_matmul: true,
155            dropout: false,
156        });
157
158        // Reduce-broadcast patterns
159        self.patterns.push(FusionPattern::ReduceBroadcast {
160            reduction: OperationType::Mean,
161            broadcast: OperationType::Broadcast,
162        });
163
164        // Modern transformer fusion patterns
165
166        // RoPE fusion for rotary position embedding
167        self.patterns.push(FusionPattern::RoPEFusion {
168            apply_rope: true,
169            cos_sin_cached: true,
170            dimensions: 128, // Common dimension for RoPE
171        });
172
173        // SwiGLU activation fusion (used in LLaMA, PaLM, etc.)
174        self.patterns.push(FusionPattern::SwiGLU {
175            gate_projection: true,
176            up_projection: true,
177            swish_activation: true,
178            element_wise_multiply: true,
179        });
180
181        // Group normalization fusion
182        self.patterns.push(FusionPattern::GroupNorm {
183            groups: 32,
184            normalize: true,
185            scale: true,
186            shift: true,
187            activation: None,
188        });
189
190        // Optimized flash attention with memory-efficient blocking
191        self.patterns.push(FusionPattern::FlashAttentionOptimized {
192            query_key_matmul: true,
193            scaled_softmax: true,
194            value_matmul: true,
195            causal_mask: true,
196            dropout: false,
197            block_size: 128, // Optimal block size for most hardware
198        });
199
200        // RMSNorm fusion (used in LLaMA and other models)
201        self.patterns.push(FusionPattern::Custom {
202            name: "RMSNorm".to_string(),
203            operations: vec![
204                OperationType::Power,    // x^2
205                OperationType::Mean,     // mean(x^2)
206                OperationType::Add,      // + eps
207                OperationType::Power,    // sqrt (power 0.5)
208                OperationType::Divide,   // x / rms
209                OperationType::Multiply, // * weight
210            ],
211            constraints: vec![
212                FusionConstraint::ShapeCompatible,
213                FusionConstraint::DataTypeCompatible,
214                FusionConstraint::Contiguous,
215            ],
216        });
217
218        // Initialize default constraints
219        self.constraints.extend(vec![
220            FusionConstraint::ShapeCompatible,
221            FusionConstraint::DataTypeCompatible,
222            FusionConstraint::DeviceCompatible,
223            FusionConstraint::MaxOperations(8),
224            FusionConstraint::MaxMemoryUsage(1024 * 1024 * 1024), // 1GB
225            FusionConstraint::Contiguous,
226        ]);
227    }
228
229    fn initialize_performance_database(&mut self) {
230        let mut db = self
231            .performance_database
232            .write()
233            .expect("performance_database lock should not be poisoned");
234
235        // Add operation costs for common operations
236        db.add_operation_cost(
237            OperationType::Add,
238            OperationCost::new(1.0, 0.1).with_launch_overhead(500),
239        );
240
241        db.add_operation_cost(
242            OperationType::Multiply,
243            OperationCost::new(1.0, 0.1).with_launch_overhead(500),
244        );
245
246        db.add_operation_cost(
247            OperationType::MatMul,
248            OperationCost::new(100.0, 1.0).with_launch_overhead(2000),
249        );
250
251        db.add_operation_cost(
252            OperationType::ReLU,
253            OperationCost::new(1.0, 0.05).with_launch_overhead(300),
254        );
255
256        db.add_operation_cost(
257            OperationType::GELU,
258            OperationCost::new(10.0, 0.1).with_launch_overhead(800),
259        );
260
261        // Add device characteristics
262        db.add_device_characteristics(Device::CPU, DeviceCharacteristics::cpu_characteristics());
263        db.add_device_characteristics(Device::GPU(0), DeviceCharacteristics::gpu_characteristics());
264    }
265
266    fn find_pattern_matches(
267        &self,
268        graph: &ComputationGraph,
269        pattern: &FusionPattern,
270    ) -> Result<Vec<FusionOpportunity>> {
271        match pattern {
272            FusionPattern::ElementWiseChain(ops) => self.find_elementwise_chains(graph, ops),
273            FusionPattern::LinearActivation { .. } => {
274                self.find_linear_activation_patterns(graph, pattern)
275            },
276            FusionPattern::AttentionFusion { .. } => self.find_attention_patterns(graph),
277            // Add more pattern matching logic for other patterns
278            _ => Ok(Vec::new()), // Placeholder for unimplemented patterns
279        }
280    }
281
282    fn find_elementwise_chains(
283        &self,
284        graph: &ComputationGraph,
285        target_ops: &[OperationType],
286    ) -> Result<Vec<FusionOpportunity>> {
287        let mut opportunities = Vec::new();
288
289        // Look for sequences of element-wise operations that match the target pattern
290        for node_id in &graph.execution_order {
291            if let Some(node) = graph.get_node(node_id) {
292                if node.operation == target_ops[0] {
293                    // Try to match the complete chain starting from this node
294                    let mut chain = vec![node_id.clone()];
295                    let mut current_id = node_id.clone();
296
297                    for target_op in target_ops.iter().skip(1) {
298                        // Find the next node in the chain
299                        if let Some(next_id) =
300                            self.find_next_operation(&current_id, target_op.clone(), graph)
301                        {
302                            chain.push(next_id.clone());
303                            current_id = next_id;
304                        } else {
305                            break;
306                        }
307                    }
308
309                    if chain.len() == target_ops.len() {
310                        let benefit = self.estimate_fusion_benefit(&chain, graph)?;
311                        let constraints_satisfied =
312                            self.verify_fusion_constraints(&chain, graph)?;
313
314                        opportunities.push(FusionOpportunity {
315                            pattern: FusionPattern::ElementWiseChain(target_ops.to_vec()),
316                            node_ids: chain,
317                            estimated_benefit: benefit,
318                            constraints_satisfied,
319                        });
320                    }
321                }
322            }
323        }
324
325        Ok(opportunities)
326    }
327
328    fn find_linear_activation_patterns(
329        &self,
330        graph: &ComputationGraph,
331        pattern: &FusionPattern,
332    ) -> Result<Vec<FusionOpportunity>> {
333        let mut opportunities = Vec::new();
334
335        // Look for MatMul -> Add -> Activation patterns
336        for node_id in &graph.execution_order {
337            if let Some(node) = graph.get_node(node_id) {
338                if node.operation == OperationType::MatMul {
339                    let mut chain = vec![node_id.clone()];
340
341                    // Look for bias add
342                    if let Some(add_id) =
343                        self.find_next_operation(node_id, OperationType::Add, graph)
344                    {
345                        chain.push(add_id.clone());
346
347                        // Look for activation
348                        if let FusionPattern::LinearActivation {
349                            activation: Some(act_type),
350                            ..
351                        } = pattern
352                        {
353                            if let Some(act_id) =
354                                self.find_next_operation(&add_id, act_type.clone(), graph)
355                            {
356                                chain.push(act_id);
357                            }
358                        }
359                    }
360
361                    if chain.len() >= 2 {
362                        // At least MatMul + Add
363                        let benefit = self.estimate_fusion_benefit(&chain, graph)?;
364                        let constraints_satisfied =
365                            self.verify_fusion_constraints(&chain, graph)?;
366
367                        opportunities.push(FusionOpportunity {
368                            pattern: pattern.clone(),
369                            node_ids: chain,
370                            estimated_benefit: benefit,
371                            constraints_satisfied,
372                        });
373                    }
374                }
375            }
376        }
377
378        Ok(opportunities)
379    }
380
381    fn find_attention_patterns(&self, graph: &ComputationGraph) -> Result<Vec<FusionOpportunity>> {
382        // Placeholder implementation for attention pattern detection
383        // In a full implementation, this would look for Q*K^T -> Softmax -> *V patterns
384        Ok(Vec::new())
385    }
386
387    fn find_next_operation(
388        &self,
389        current_id: &str,
390        target_op: OperationType,
391        graph: &ComputationGraph,
392    ) -> Option<String> {
393        // Find consumers of the current node
394        for (node_id, dependencies) in &graph.edges {
395            if dependencies.contains(&current_id.to_string()) {
396                if let Some(node) = graph.get_node(node_id) {
397                    if node.operation == target_op {
398                        return Some(node_id.clone());
399                    }
400                }
401            }
402        }
403        None
404    }
405
406    fn verify_fusion_constraints(
407        &self,
408        node_ids: &[String],
409        graph: &ComputationGraph,
410    ) -> Result<bool> {
411        let nodes: Vec<&GraphNode> = node_ids.iter().filter_map(|id| graph.get_node(id)).collect();
412
413        if nodes.len() != node_ids.len() {
414            return Ok(false); // Some nodes not found
415        }
416
417        for constraint in &self.constraints {
418            match constraint {
419                FusionConstraint::ShapeCompatible if !self.check_shape_compatibility(&nodes)? => {
420                    return Ok(false);
421                },
422                FusionConstraint::DataTypeCompatible
423                    if !self.check_data_type_compatibility(&nodes)? =>
424                {
425                    return Ok(false);
426                },
427                FusionConstraint::DeviceCompatible
428                    if !self.check_device_compatibility(&nodes)? =>
429                {
430                    return Ok(false);
431                },
432                FusionConstraint::MaxOperations(max_ops) if nodes.len() > *max_ops => {
433                    return Ok(false);
434                },
435                FusionConstraint::Contiguous if !self.check_contiguity(node_ids, graph)? => {
436                    return Ok(false);
437                },
438                // Add more constraint checks as needed
439                _ => {}, // Placeholder for other constraints
440            }
441        }
442
443        Ok(true)
444    }
445
446    fn check_shape_compatibility(&self, nodes: &[&GraphNode]) -> Result<bool> {
447        if nodes.is_empty() {
448            return Ok(true);
449        }
450
451        // Check if all output shapes are compatible (can be broadcasted or are identical)
452        let first_output_shape =
453            &nodes[0].outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.shape;
454
455        for node in nodes.iter().skip(1) {
456            let output_shape =
457                &node.outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.shape;
458
459            if !self.shapes_broadcastable(first_output_shape, output_shape) {
460                return Ok(false);
461            }
462        }
463
464        Ok(true)
465    }
466
467    pub fn shapes_broadcastable(&self, shape1: &[usize], shape2: &[usize]) -> bool {
468        let max_len = shape1.len().max(shape2.len());
469
470        for i in 0..max_len {
471            let dim1 = shape1.get(shape1.len().saturating_sub(max_len - i)).copied().unwrap_or(1);
472            let dim2 = shape2.get(shape2.len().saturating_sub(max_len - i)).copied().unwrap_or(1);
473
474            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
475                return false;
476            }
477        }
478
479        true
480    }
481
482    fn check_data_type_compatibility(&self, nodes: &[&GraphNode]) -> Result<bool> {
483        if nodes.is_empty() {
484            return Ok(true);
485        }
486
487        let first_dtype =
488            &nodes[0].outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.dtype;
489
490        for node in nodes.iter().skip(1) {
491            let dtype = &node.outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.dtype;
492
493            if dtype != first_dtype {
494                return Ok(false);
495            }
496        }
497
498        Ok(true)
499    }
500
501    fn check_device_compatibility(&self, nodes: &[&GraphNode]) -> Result<bool> {
502        if nodes.is_empty() {
503            return Ok(true);
504        }
505
506        let first_device =
507            &nodes[0].outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.device;
508
509        for node in nodes.iter().skip(1) {
510            let device =
511                &node.outputs.first().ok_or_else(|| anyhow!("Node has no outputs"))?.device;
512
513            if device != first_device {
514                return Ok(false);
515            }
516        }
517
518        Ok(true)
519    }
520
521    fn check_contiguity(&self, node_ids: &[String], graph: &ComputationGraph) -> Result<bool> {
522        // Check if nodes are contiguous in the execution order
523        let execution_positions: HashMap<String, usize> = graph
524            .execution_order
525            .iter()
526            .enumerate()
527            .map(|(i, id)| (id.clone(), i))
528            .collect();
529
530        let mut positions: Vec<usize> =
531            node_ids.iter().filter_map(|id| execution_positions.get(id)).copied().collect();
532
533        if positions.len() != node_ids.len() {
534            return Ok(false); // Some nodes not in execution order
535        }
536
537        positions.sort();
538
539        // Check if positions are consecutive
540        for i in 1..positions.len() {
541            if positions[i] != positions[i - 1] + 1 {
542                return Ok(false);
543            }
544        }
545
546        Ok(true)
547    }
548
549    fn estimate_fusion_benefit(
550        &self,
551        node_ids: &[String],
552        graph: &ComputationGraph,
553    ) -> Result<f64> {
554        let db = self
555            .performance_database
556            .read()
557            .expect("performance_database lock should not be poisoned");
558
559        let mut total_individual_cost = 0.0;
560        let mut _total_ops = 0u64;
561
562        for node_id in node_ids {
563            if let Some(node) = graph.get_node(node_id) {
564                if let Some(cost) = db.get_operation_cost(&node.operation) {
565                    let elements = node.outputs.first().map(|t| t.element_count()).unwrap_or(1);
566
567                    total_individual_cost +=
568                        cost.ops_per_element * elements as f64 + cost.launch_overhead_ns as f64;
569                    _total_ops += node.metadata.estimated_ops;
570                }
571            }
572        }
573
574        // Estimate fused cost (reduced launch overhead, better cache utilization)
575        let launch_overhead_reduction = (node_ids.len() - 1) as f64 * 1000.0; // Save 1µs per avoided launch
576        let cache_efficiency_gain = 1.2; // 20% improvement from better cache utilization
577
578        let fused_cost =
579            (total_individual_cost - launch_overhead_reduction) / cache_efficiency_gain;
580
581        let speedup = if fused_cost > 0.0 { total_individual_cost / fused_cost } else { 1.0 };
582
583        Ok(speedup)
584    }
585
586    fn generate_kernel_name(&self, pattern: &FusionPattern) -> String {
587        match pattern {
588            FusionPattern::ElementWiseChain(ops) => {
589                let op_names: Vec<String> =
590                    ops.iter().map(|op| format!("{:?}", op).to_lowercase()).collect();
591                format!("elementwise_{}", op_names.join("_"))
592            },
593            FusionPattern::LinearActivation { activation, .. } => match activation {
594                Some(act) => format!("linear_{:?}", act).to_lowercase(),
595                None => "linear".to_string(),
596            },
597            FusionPattern::AttentionFusion { .. } => "attention_fusion".to_string(),
598            FusionPattern::BatchNorm { .. } => "batch_norm".to_string(),
599            FusionPattern::Custom { name, .. } => name.to_lowercase(),
600            _ => "custom_fusion".to_string(),
601        }
602    }
603
604    fn generate_kernel_implementation(
605        &self,
606        opportunity: &FusionOpportunity,
607    ) -> Result<KernelImplementation> {
608        // For simplicity, generate CPU implementation
609        // In a full implementation, this would choose based on device capabilities
610        self.generate_cpu_kernel(opportunity)
611    }
612
613    fn generate_cpu_kernel(&self, opportunity: &FusionOpportunity) -> Result<KernelImplementation> {
614        let kernel_code = match &opportunity.pattern {
615            FusionPattern::ElementWiseChain(ops) => self.generate_elementwise_cpu_code(ops),
616            FusionPattern::LinearActivation { .. } => self.generate_linear_activation_cpu_code(),
617            _ => "// Generic fused kernel implementation".to_string(),
618        };
619
620        Ok(KernelImplementation::CPU(kernel_code))
621    }
622
623    fn generate_elementwise_cpu_code(&self, ops: &[OperationType]) -> String {
624        let mut code = String::new();
625        code.push_str("void fused_elementwise_kernel(float* input, float* output, int size) {\n");
626        code.push_str("    #pragma omp parallel for\n");
627        code.push_str("    for (int i = 0; i < size; i++) {\n");
628        code.push_str("        float value = input[i];\n");
629
630        for op in ops {
631            match op {
632                OperationType::Add => code.push_str("        value = value + 1.0f; // Simplified\n"),
633                OperationType::ReLU => code.push_str("        value = fmaxf(0.0f, value);\n"),
634                OperationType::GELU => code.push_str("        value = 0.5f * value * (1.0f + tanhf(0.797885f * (value + 0.044715f * value * value * value)));\n"),
635                _ => code.push_str("        // Other operation\n"),
636            }
637        }
638
639        code.push_str("        output[i] = value;\n");
640        code.push_str("    }\n");
641        code.push_str("}\n");
642
643        code
644    }
645
646    fn generate_linear_activation_cpu_code(&self) -> String {
647        r#"
648void fused_linear_activation_kernel(
649    float* input, float* weight, float* bias, float* output,
650    int batch_size, int input_dim, int output_dim
651) {
652    #pragma omp parallel for
653    for (int b = 0; b < batch_size; b++) {
654        for (int o = 0; o < output_dim; o++) {
655            float sum = bias[o];
656            for (int i = 0; i < input_dim; i++) {
657                sum += input[b * input_dim + i] * weight[o * input_dim + i];
658            }
659            // Apply ReLU activation
660            output[b * output_dim + o] = fmaxf(0.0f, sum);
661        }
662    }
663}
664        "#
665        .to_string()
666    }
667
668    /// Calculate memory savings from fusing operations by eliminating intermediate tensors
669    fn calculate_memory_savings(
670        &self,
671        graph: &ComputationGraph,
672        node_ids: &[String],
673    ) -> Result<u64> {
674        let mut total_memory_saved = 0u64;
675
676        // For each node in the fusion (except the last one), calculate memory of intermediate outputs
677        // that will be eliminated by fusion
678        for (i, node_id) in node_ids.iter().enumerate() {
679            // Skip the last node as its output is still needed
680            if i == node_ids.len() - 1 {
681                continue;
682            }
683
684            let node = graph
685                .nodes
686                .get(node_id)
687                .ok_or_else(|| anyhow!("Node {} not found in graph", node_id))?;
688
689            // Calculate memory used by this node's output tensors that will be eliminated
690            for output in &node.outputs {
691                // Only count intermediate tensors that are consumed only by nodes within the fusion
692                if self.is_intermediate_tensor_in_fusion(node_id, output, graph, node_ids)? {
693                    total_memory_saved += output.memory_size() as u64;
694                }
695            }
696        }
697
698        Ok(total_memory_saved)
699    }
700
701    /// Check if a tensor is intermediate (only consumed within the fusion group)
702    fn is_intermediate_tensor_in_fusion(
703        &self,
704        producer_id: &str,
705        _tensor: &TensorInfo,
706        graph: &ComputationGraph,
707        fusion_node_ids: &[String],
708    ) -> Result<bool> {
709        let fusion_set: HashSet<String> = fusion_node_ids.iter().cloned().collect();
710
711        // Find all consumers of this producer node
712        let mut consumers = Vec::new();
713        for (node_id, dependencies) in &graph.edges {
714            if dependencies.contains(&producer_id.to_string()) {
715                consumers.push(node_id);
716            }
717        }
718
719        // If all consumers are within the fusion group, then this is an intermediate tensor
720        Ok(
721            !consumers.is_empty()
722                && consumers.iter().all(|consumer| fusion_set.contains(*consumer)),
723        )
724    }
725
726    fn pattern_name(&self, pattern: &FusionPattern) -> String {
727        match pattern {
728            FusionPattern::ElementWiseChain(_) => "ElementWiseChain".to_string(),
729            FusionPattern::LinearActivation { .. } => "LinearActivation".to_string(),
730            FusionPattern::AttentionFusion { .. } => "AttentionFusion".to_string(),
731            FusionPattern::BatchNorm { .. } => "BatchNorm".to_string(),
732            FusionPattern::Custom { name, .. } => name.clone(),
733            _ => "Unknown".to_string(),
734        }
735    }
736}
737
738impl Default for KernelFusionEngine {
739    fn default() -> Self {
740        Self::new()
741    }
742}