Skip to main content

trustformers_core/kernel_fusion/
kernel.rs

1//! Fused kernel representation and implementation
2//!
3//! This module defines structures for representing fused kernels and their
4//! various backend implementations.
5
6use crate::kernel_fusion::graph::TensorInfo;
7use crate::kernel_fusion::operation_types::FusionPattern;
8
9/// Fused kernel representation
10#[derive(Debug, Clone)]
11pub struct FusedKernel {
12    pub id: String,
13    pub name: String,
14    pub pattern: FusionPattern,
15    pub operations: Vec<String>, // Original operation IDs
16    pub inputs: Vec<TensorInfo>,
17    pub outputs: Vec<TensorInfo>,
18    pub estimated_speedup: f64,
19    pub memory_savings: usize,
20    pub implementation: KernelImplementation,
21}
22
/// Backend-specific source for a fused kernel.
///
/// Each variant names the target platform and carries the kernel's code as a
/// `String`.
// The all-caps variant names (`CUDA`, `CPU`, `SIMD`, `ASIC`) trip
// `clippy::upper_case_acronyms`, but they are public API and renaming them
// would break every caller that matches on them, so the lint is silenced here.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub enum KernelImplementation {
    /// CUDA kernel code.
    CUDA(String),
    /// ROCm/HIP kernel code.
    ROCm(String),
    /// OpenCL kernel code.
    OpenCL(String),
    /// CPU implementation.
    CPU(String),
    /// Vulkan compute shader.
    Vulkan(String),
    /// Metal compute shader.
    Metal(String),
    /// WebGPU shader.
    WebGPU(String),
    /// SIMD intrinsics.
    SIMD(String),
    /// ASIC-specific kernel code.
    ASIC(String),
}
35
36impl FusedKernel {
37    pub fn new(id: String, name: String, pattern: FusionPattern, operations: Vec<String>) -> Self {
38        Self {
39            id,
40            name,
41            pattern,
42            operations,
43            inputs: Vec::new(),
44            outputs: Vec::new(),
45            estimated_speedup: 1.0,
46            memory_savings: 0,
47            implementation: KernelImplementation::CPU("".to_string()),
48        }
49    }
50
51    pub fn with_inputs(mut self, inputs: Vec<TensorInfo>) -> Self {
52        self.inputs = inputs;
53        self
54    }
55
56    pub fn with_outputs(mut self, outputs: Vec<TensorInfo>) -> Self {
57        self.outputs = outputs;
58        self
59    }
60
61    pub fn with_speedup(mut self, speedup: f64) -> Self {
62        self.estimated_speedup = speedup;
63        self
64    }
65
66    pub fn with_memory_savings(mut self, savings: usize) -> Self {
67        self.memory_savings = savings;
68        self
69    }
70
71    pub fn with_implementation(mut self, implementation: KernelImplementation) -> Self {
72        self.implementation = implementation;
73        self
74    }
75}
76
77impl KernelImplementation {
78    pub fn platform(&self) -> &'static str {
79        match self {
80            KernelImplementation::CUDA(_) => "CUDA",
81            KernelImplementation::ROCm(_) => "ROCm",
82            KernelImplementation::OpenCL(_) => "OpenCL",
83            KernelImplementation::CPU(_) => "CPU",
84            KernelImplementation::Vulkan(_) => "Vulkan",
85            KernelImplementation::Metal(_) => "Metal",
86            KernelImplementation::WebGPU(_) => "WebGPU",
87            KernelImplementation::SIMD(_) => "SIMD",
88            KernelImplementation::ASIC(_) => "ASIC",
89        }
90    }
91
92    pub fn code(&self) -> &str {
93        match self {
94            KernelImplementation::CUDA(code)
95            | KernelImplementation::ROCm(code)
96            | KernelImplementation::OpenCL(code)
97            | KernelImplementation::CPU(code)
98            | KernelImplementation::Vulkan(code)
99            | KernelImplementation::Metal(code)
100            | KernelImplementation::WebGPU(code)
101            | KernelImplementation::SIMD(code)
102            | KernelImplementation::ASIC(code) => code,
103        }
104    }
105}