use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

#[cfg(feature = "gpu")]
use crate::gpu::{GpuBackend, GpuDevice, GpuError, GpuKernel};
/// Uniquely identifies a GPU kernel by module, operation, element dtype,
/// and an optional variant tag. Used as the lookup key in the registry.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct KernelId {
    /// Owning module (e.g. "core", "linalg").
    pub module: String,
    /// Operation name (e.g. "gemm", "reduce_sum").
    pub operation: String,
    /// Element type tag (e.g. "f32", "f64", "c64").
    pub dtype: String,
    /// Optional variant qualifier (e.g. "batched"); `None` selects the default variant.
    pub variant: Option<String>,
}
25
26impl KernelId {
27 pub fn new(module: &str, operation: &str, dtype: &str) -> Self {
29 Self {
30 module: module.to_string(),
31 operation: operation.to_string(),
32 dtype: dtype.to_string(),
33 variant: None,
34 }
35 }
36
37 pub fn with_variant(module: &str, operation: &str, dtype: &str, variant: &str) -> Self {
39 Self {
40 module: module.to_string(),
41 operation: operation.to_string(),
42 dtype: dtype.to_string(),
43 variant: Some(variant.to_string()),
44 }
45 }
46
47 pub fn as_kernel_name(&self) -> String {
49 match &self.variant {
50 Some(variant) => format!(
51 "{}_{}_{}__{}",
52 self.module, self.operation, self.dtype, variant
53 ),
54 None => format!(
55 "{module}_{operation}__{dtype}",
56 module = self.module,
57 operation = self.operation,
58 dtype = self.dtype
59 ),
60 }
61 }
62}
63
/// Source code and launch metadata for one backend-specific implementation
/// of a kernel. A single `KernelId` may map to several of these, one per backend.
#[derive(Debug, Clone)]
pub struct KernelSource {
    /// Kernel source text (the built-ins embed CUDA files via `include_str!`).
    pub source: String,
    /// Backend this source targets.
    // NOTE(review): `GpuBackend` is only imported under `cfg(feature = "gpu")`
    // in the visible imports, yet this field is not feature-gated — confirm a
    // non-gated import exists in the elided portion of the file.
    pub backend: GpuBackend,
    /// Name of the entry-point function inside `source`.
    pub entry_point: String,
    /// Launch workgroup/block dimensions as (x, y, z).
    pub workgroup_size: (u32, u32, u32),
    /// Shared/local memory requested per workgroup, in bytes.
    pub shared_memory: usize,
    /// Whether the kernel is written to use tensor-core instructions.
    pub uses_tensor_cores: bool,
}
80
/// A compiled kernel cached for a specific device.
#[cfg(feature = "gpu")]
struct CompiledKernel {
    /// Shared handle to the compiled kernel object.
    kernel: Arc<GpuKernel>,
    /// Device the kernel was compiled for (mirrors the cache key's device id).
    device_id: usize,
}
87
/// Process-wide kernel registry, lazily initialized on first access via
/// [`KernelRegistry::global`].
static KERNEL_REGISTRY: OnceLock<Mutex<KernelRegistry>> = OnceLock::new();
90
/// Registry mapping kernel ids to their per-backend sources, plus (under the
/// `gpu` feature) a cache of kernels already compiled for each device.
pub struct KernelRegistry {
    /// All registered sources, keyed by kernel id; one entry per backend.
    sources: HashMap<KernelId, Vec<KernelSource>>,
    /// Compiled kernels keyed by (kernel id, device id).
    #[cfg(feature = "gpu")]
    compiled_cache: HashMap<(KernelId, usize), CompiledKernel>,
}
99
100impl KernelRegistry {
101 fn new() -> Self {
103 Self {
104 sources: HashMap::new(),
105 #[cfg(feature = "gpu")]
106 compiled_cache: HashMap::new(),
107 }
108 }
109
110 pub fn global() -> &'static Mutex<KernelRegistry> {
112 KERNEL_REGISTRY.get_or_init(|| {
113 let mut registry = KernelRegistry::new();
114 registry.register_builtin_kernels();
116 Mutex::new(registry)
117 })
118 }
119
120 fn register_builtin_kernels(&mut self) {
122 self.register_blas_kernels();
124
125 self.register_reduction_kernels();
127
128 self.register_utility_kernels();
130 }
131
132 pub fn register_kernel(&mut self, id: KernelId, source: KernelSource) {
134 self.sources.entry(id).or_default().push(source);
135 }
136
137 pub fn get_sources(&self, id: &KernelId) -> Option<&Vec<KernelSource>> {
139 self.sources.get(id)
140 }
141
142 #[cfg(feature = "gpu")]
144 pub fn get_kernel(
145 &mut self,
146 id: &KernelId,
147 device: &GpuDevice,
148 ) -> Result<Arc<GpuKernel>, GpuError> {
149 let device_id = device.device_id();
150 let cache_key = (id.clone(), device_id);
151
152 if let Some(cached) = self.compiled_cache.get(&cache_key) {
154 if cached.device_id == device_id {
155 return Ok(cached.kernel.clone());
156 }
157 }
158
159 let sources = self
161 .sources
162 .get(id)
163 .ok_or_else(|| GpuError::KernelNotFound(id.as_kernel_name()))?;
164
165 let source = sources
166 .iter()
167 .find(|s| s.backend == device.backend())
168 .ok_or_else(|| GpuError::BackendNotSupported(device.backend()))?;
169
170 let kernel = device.compile_kernel(&source.source, &source.entry_point)?;
172 let kernel = Arc::new(kernel);
173
174 self.compiled_cache.insert(
176 cache_key,
177 CompiledKernel {
178 kernel: kernel.clone(),
179 device_id,
180 },
181 );
182
183 Ok(kernel)
184 }
185
186 #[cfg(feature = "gpu")]
188 pub fn clear_cache(&mut self) {
189 self.compiled_cache.clear();
190 }
191
192 pub fn list_kernels(&self) -> Vec<KernelId> {
194 self.sources.keys().cloned().collect()
195 }
196
197 pub fn has_kernel(&self, id: &KernelId) -> bool {
199 self.sources.contains_key(id)
200 }
201}
202
203impl KernelRegistry {
205 fn register_blas_kernels(&mut self) {
206 self.register_kernel(
208 KernelId::new("core", "gemm", "f32"),
209 KernelSource {
210 source: include_str!("gpu/kernels/gemm_f32.cu").to_string(),
211 backend: GpuBackend::Cuda,
212 entry_point: "gemm_f32".to_string(),
213 workgroup_size: (16, 16, 1),
214 shared_memory: 4096,
215 uses_tensor_cores: false,
216 },
217 );
218
219 self.register_kernel(
221 KernelId::new("core", "gemm", "f64"),
222 KernelSource {
223 source: include_str!("gpu/kernels/gemm_f64.cu").to_string(),
224 backend: GpuBackend::Cuda,
225 entry_point: "gemm_f64".to_string(),
226 workgroup_size: (16, 16, 1),
227 shared_memory: 8192,
228 uses_tensor_cores: false,
229 },
230 );
231
232 self.register_kernel(
234 KernelId::new("core", "axpy", "f32"),
235 KernelSource {
236 source: include_str!("gpu/kernels/axpy.cu").to_string(),
237 backend: GpuBackend::Cuda,
238 entry_point: "axpy_f32".to_string(),
239 workgroup_size: (256, 1, 1),
240 shared_memory: 0,
241 uses_tensor_cores: false,
242 },
243 );
244 }
245
246 fn register_reduction_kernels(&mut self) {
247 self.register_kernel(
249 KernelId::new("core", "reduce_sum", "f32"),
250 KernelSource {
251 source: include_str!("gpu/kernels/reduce_sum.cu").to_string(),
252 backend: GpuBackend::Cuda,
253 entry_point: "reduce_sum_f32".to_string(),
254 workgroup_size: (256, 1, 1),
255 shared_memory: 1024,
256 uses_tensor_cores: false,
257 },
258 );
259
260 self.register_kernel(
262 KernelId::new("core", "reduce_max", "f32"),
263 KernelSource {
264 source: include_str!("gpu/kernels/reduce_max.cu").to_string(),
265 backend: GpuBackend::Cuda,
266 entry_point: "reduce_max_f32".to_string(),
267 workgroup_size: (256, 1, 1),
268 shared_memory: 1024,
269 uses_tensor_cores: false,
270 },
271 );
272 }
273
274 fn register_utility_kernels(&mut self) {
275 self.register_kernel(
277 KernelId::new("core", "memcpy", "f32"),
278 KernelSource {
279 source: include_str!("gpu/kernels/memcpy.cu").to_string(),
280 backend: GpuBackend::Cuda,
281 entry_point: "memcpy_f32".to_string(),
282 workgroup_size: (256, 1, 1),
283 shared_memory: 0,
284 uses_tensor_cores: false,
285 },
286 );
287
288 self.register_kernel(
290 KernelId::new("core", "fill", "f32"),
291 KernelSource {
292 source: include_str!("gpu/kernels/fill.cu").to_string(),
293 backend: GpuBackend::Cuda,
294 entry_point: "fill_f32".to_string(),
295 workgroup_size: (256, 1, 1),
296 shared_memory: 0,
297 uses_tensor_cores: false,
298 },
299 );
300 }
301}
302
303#[allow(dead_code)]
335pub fn register_module_kernel(id: KernelId, source: KernelSource) {
336 let registry = KernelRegistry::global();
337 let mut registry = registry.lock().expect("Operation failed");
338 registry.register_kernel(id, source);
339}
340
/// Convenience wrapper: fetches (compiling and caching if needed) the kernel
/// `id` for `device` from the process-wide registry.
///
/// # Errors
/// Propagates lookup/compilation errors from [`KernelRegistry::get_kernel`].
///
/// # Panics
/// Panics if the global registry mutex is poisoned.
#[cfg(feature = "gpu")]
#[allow(dead_code)]
pub fn get_kernel(id: &KernelId, device: &GpuDevice) -> Result<Arc<GpuKernel>, GpuError> {
    let registry = KernelRegistry::global();
    let mut registry = registry.lock().expect("kernel registry mutex poisoned");
    registry.get_kernel(id, device)
}
349
350#[allow(dead_code)]
352pub fn has_kernel(id: &KernelId) -> bool {
353 let registry = KernelRegistry::global();
354 let registry = registry.lock().expect("Operation failed");
355 registry.has_kernel(id)
356}
357
358#[allow(dead_code)]
360pub fn list_kernels() -> Vec<KernelId> {
361 let registry = KernelRegistry::global();
362 let registry = registry.lock().expect("Operation failed");
363 registry.list_kernels()
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_id() {
        // Default-variant ids render as `module_op__dtype`.
        let plain = KernelId::new("linalg", "gemm", "f32");
        assert_eq!(plain.as_kernel_name(), "linalg_gemm__f32");

        // Variant ids render as `module_op_dtype__variant`.
        let batched = KernelId::with_variant("fft", "fft2d", "c64", "batched");
        assert_eq!(batched.as_kernel_name(), "fft_fft2d_c64__batched");
    }

    #[test]
    fn test_kernel_registration() {
        // Register a throwaway kernel in the global registry, then look it up.
        let id = KernelId::new("test", "dummy", "f32");
        let source = KernelSource {
            source: "dummy kernel".to_string(),
            backend: GpuBackend::Cuda,
            entry_point: "dummy".to_string(),
            workgroup_size: (1, 1, 1),
            shared_memory: 0,
            uses_tensor_cores: false,
        };

        register_module_kernel(id.clone(), source);
        assert!(has_kernel(&id));
    }

    #[test]
    fn test_builtin_kernels() {
        // The global registry self-registers these f32 built-ins on first access.
        for op in ["gemm", "reduce_sum", "fill"] {
            assert!(has_kernel(&KernelId::new("core", op, "f32")));
        }
    }
}