// trustformers_core/kernel_fusion/
kernel.rs1use crate::kernel_fusion::graph::TensorInfo;
7use crate::kernel_fusion::operation_types::FusionPattern;
8
9#[derive(Debug, Clone)]
11pub struct FusedKernel {
12 pub id: String,
13 pub name: String,
14 pub pattern: FusionPattern,
15 pub operations: Vec<String>, pub inputs: Vec<TensorInfo>,
17 pub outputs: Vec<TensorInfo>,
18 pub estimated_speedup: f64,
19 pub memory_savings: usize,
20 pub implementation: KernelImplementation,
21}
22
/// Kernel code for a `FusedKernel`, tagged with the platform it targets.
///
/// Every variant carries the kernel code as a `String`.
#[derive(Debug, Clone)]
pub enum KernelImplementation {
    /// CUDA kernel code.
    CUDA(String),
    /// ROCm kernel code.
    ROCm(String),
    /// OpenCL kernel code.
    OpenCL(String),
    /// CPU implementation code.
    CPU(String),
    /// Vulkan kernel code.
    Vulkan(String),
    /// Metal kernel code.
    Metal(String),
    /// WebGPU kernel code.
    WebGPU(String),
    /// SIMD (vectorized) implementation code.
    SIMD(String),
    /// Code for an ASIC target.
    ASIC(String),
}

36impl FusedKernel {
37 pub fn new(id: String, name: String, pattern: FusionPattern, operations: Vec<String>) -> Self {
38 Self {
39 id,
40 name,
41 pattern,
42 operations,
43 inputs: Vec::new(),
44 outputs: Vec::new(),
45 estimated_speedup: 1.0,
46 memory_savings: 0,
47 implementation: KernelImplementation::CPU("".to_string()),
48 }
49 }
50
51 pub fn with_inputs(mut self, inputs: Vec<TensorInfo>) -> Self {
52 self.inputs = inputs;
53 self
54 }
55
56 pub fn with_outputs(mut self, outputs: Vec<TensorInfo>) -> Self {
57 self.outputs = outputs;
58 self
59 }
60
61 pub fn with_speedup(mut self, speedup: f64) -> Self {
62 self.estimated_speedup = speedup;
63 self
64 }
65
66 pub fn with_memory_savings(mut self, savings: usize) -> Self {
67 self.memory_savings = savings;
68 self
69 }
70
71 pub fn with_implementation(mut self, implementation: KernelImplementation) -> Self {
72 self.implementation = implementation;
73 self
74 }
75}
76
77impl KernelImplementation {
78 pub fn platform(&self) -> &'static str {
79 match self {
80 KernelImplementation::CUDA(_) => "CUDA",
81 KernelImplementation::ROCm(_) => "ROCm",
82 KernelImplementation::OpenCL(_) => "OpenCL",
83 KernelImplementation::CPU(_) => "CPU",
84 KernelImplementation::Vulkan(_) => "Vulkan",
85 KernelImplementation::Metal(_) => "Metal",
86 KernelImplementation::WebGPU(_) => "WebGPU",
87 KernelImplementation::SIMD(_) => "SIMD",
88 KernelImplementation::ASIC(_) => "ASIC",
89 }
90 }
91
92 pub fn code(&self) -> &str {
93 match self {
94 KernelImplementation::CUDA(code)
95 | KernelImplementation::ROCm(code)
96 | KernelImplementation::OpenCL(code)
97 | KernelImplementation::CPU(code)
98 | KernelImplementation::Vulkan(code)
99 | KernelImplementation::Metal(code)
100 | KernelImplementation::WebGPU(code)
101 | KernelImplementation::SIMD(code)
102 | KernelImplementation::ASIC(code) => code,
103 }
104 }
105}