// scirs2_core/gpu_registry.rs
1//! GPU Kernel Registry System
2//!
3//! This module provides a centralized registry for all GPU kernels used across
4//! the SciRS2 ecosystem. All modules must register their GPU kernels here
5//! instead of implementing them directly.
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10#[cfg(feature = "gpu")]
11use crate::gpu::{GpuBackend, GpuDevice, GpuError, GpuKernel};
12
13/// GPU kernel identifier
/// GPU kernel identifier
///
/// Uniquely identifies a kernel by owning module, operation, element type,
/// and an optional variant. Derives `Hash`/`Eq` so it can serve as the key
/// of the registry's source map.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct KernelId {
    /// Module that owns this kernel (e.g., "linalg", "fft", "neural")
    pub module: String,
    /// Operation name (e.g., "gemm", "fft2d", "conv2d")
    pub operation: String,
    /// Data type (e.g., "f32", "f64", "c32", "c64")
    pub dtype: String,
    /// Optional variant (e.g., "transposed", "batched", "strided")
    pub variant: Option<String>,
}

impl KernelId {
    /// Create a new kernel identifier with no variant.
    pub fn new(module: &str, operation: &str, dtype: &str) -> Self {
        Self {
            variant: None,
            module: module.to_owned(),
            operation: operation.to_owned(),
            dtype: dtype.to_owned(),
        }
    }

    /// Create a new kernel identifier carrying a variant tag.
    pub fn with_variant(module: &str, operation: &str, dtype: &str, variant: &str) -> Self {
        Self {
            variant: Some(variant.to_owned()),
            ..Self::new(module, operation, dtype)
        }
    }

    /// Render the identifier as a kernel name.
    ///
    /// The final component is preceded by a double underscore: with a
    /// variant the name is `module_operation_dtype__variant`, without one
    /// it is `module_operation__dtype`.
    pub fn as_kernel_name(&self) -> String {
        if let Some(variant) = self.variant.as_deref() {
            format!("{}_{}_{}__{}", self.module, self.operation, self.dtype, variant)
        } else {
            format!("{}_{}__{}", self.module, self.operation, self.dtype)
        }
    }
}
63
/// GPU kernel source code and metadata
///
/// One `KernelSource` targets a single backend; several sources may be
/// registered under the same `KernelId` (one per backend).
///
/// NOTE(review): the `backend` field uses `GpuBackend`, which is imported at
/// the top of this file under `#[cfg(feature = "gpu")]`, while this struct is
/// not feature-gated — confirm this module compiles with the "gpu" feature
/// disabled.
#[derive(Debug, Clone)]
pub struct KernelSource {
    /// The actual kernel source code
    pub source: String,
    /// The backend this kernel is written for
    pub backend: GpuBackend,
    /// Entry point function name
    pub entry_point: String,
    /// Required workgroup/block size (x, y, z)
    pub workgroup_size: (u32, u32, u32),
    /// Shared memory requirements in bytes
    pub shared_memory: usize,
    /// Whether this kernel uses tensor cores or similar
    pub uses_tensor_cores: bool,
}
80
/// Compiled kernel cache entry
///
/// Pairs the compiled kernel with the id of the device it was compiled for.
/// NOTE(review): the cache key in `KernelRegistry::compiled_cache` already
/// includes the device id, so `device_id` here duplicates information in
/// the key.
#[cfg(feature = "gpu")]
struct CompiledKernel {
    kernel: Arc<GpuKernel>,
    device_id: usize,
}
87
/// Global GPU kernel registry, lazily initialized on first access through
/// `KernelRegistry::global()` (which also registers the built-in kernels).
static KERNEL_REGISTRY: OnceLock<Mutex<KernelRegistry>> = OnceLock::new();
90
/// GPU kernel registry
///
/// Maps each `KernelId` to the list of registered sources (typically one per
/// backend) and, when the "gpu" feature is enabled, caches compiled kernels
/// per (kernel id, device id) pair.
pub struct KernelRegistry {
    /// Registered kernel sources; multiple sources may exist per id.
    sources: HashMap<KernelId, Vec<KernelSource>>,
    /// Compiled kernel cache, keyed by (kernel id, device id).
    #[cfg(feature = "gpu")]
    compiled_cache: HashMap<(KernelId, usize), CompiledKernel>,
}
99
100impl KernelRegistry {
101    /// Create a new kernel registry
102    fn new() -> Self {
103        Self {
104            sources: HashMap::new(),
105            #[cfg(feature = "gpu")]
106            compiled_cache: HashMap::new(),
107        }
108    }
109
110    /// Get the global kernel registry
111    pub fn global() -> &'static Mutex<KernelRegistry> {
112        KERNEL_REGISTRY.get_or_init(|| {
113            let mut registry = KernelRegistry::new();
114            // Register built-in kernels
115            registry.register_builtin_kernels();
116            Mutex::new(registry)
117        })
118    }
119
120    /// Register built-in kernels from scirs2-core
121    fn register_builtin_kernels(&mut self) {
122        // Register BLAS kernels
123        self.register_blas_kernels();
124
125        // Register reduction kernels
126        self.register_reduction_kernels();
127
128        // Register utility kernels
129        self.register_utility_kernels();
130    }
131
132    /// Register a kernel source
133    pub fn register_kernel(&mut self, id: KernelId, source: KernelSource) {
134        self.sources.entry(id).or_default().push(source);
135    }
136
137    /// Get kernel sources for a given ID
138    pub fn get_sources(&self, id: &KernelId) -> Option<&Vec<KernelSource>> {
139        self.sources.get(id)
140    }
141
142    /// Get a compiled kernel for the current device
143    #[cfg(feature = "gpu")]
144    pub fn get_kernel(
145        &mut self,
146        id: &KernelId,
147        device: &GpuDevice,
148    ) -> Result<Arc<GpuKernel>, GpuError> {
149        let device_id = device.device_id();
150        let cache_key = (id.clone(), device_id);
151
152        // Check cache first
153        if let Some(cached) = self.compiled_cache.get(&cache_key) {
154            if cached.device_id == device_id {
155                return Ok(cached.kernel.clone());
156            }
157        }
158
159        // Find appropriate source for the device's backend
160        let sources = self
161            .sources
162            .get(id)
163            .ok_or_else(|| GpuError::KernelNotFound(id.as_kernel_name()))?;
164
165        let source = sources
166            .iter()
167            .find(|s| s.backend == device.backend())
168            .ok_or_else(|| GpuError::BackendNotSupported(device.backend()))?;
169
170        // Compile the kernel
171        let kernel = device.compile_kernel(&source.source, &source.entry_point)?;
172        let kernel = Arc::new(kernel);
173
174        // Cache the compiled kernel
175        self.compiled_cache.insert(
176            cache_key,
177            CompiledKernel {
178                kernel: kernel.clone(),
179                device_id,
180            },
181        );
182
183        Ok(kernel)
184    }
185
186    /// Clear the compiled kernel cache
187    #[cfg(feature = "gpu")]
188    pub fn clear_cache(&mut self) {
189        self.compiled_cache.clear();
190    }
191
192    /// List all registered kernels
193    pub fn list_kernels(&self) -> Vec<KernelId> {
194        self.sources.keys().cloned().collect()
195    }
196
197    /// Check if a kernel is registered
198    pub fn has_kernel(&self, id: &KernelId) -> bool {
199        self.sources.contains_key(id)
200    }
201}
202
203// Built-in kernel registration implementations
204impl KernelRegistry {
205    fn register_blas_kernels(&mut self) {
206        // GEMM kernels for f32
207        self.register_kernel(
208            KernelId::new("core", "gemm", "f32"),
209            KernelSource {
210                source: include_str!("gpu/kernels/gemm_f32.cu").to_string(),
211                backend: GpuBackend::Cuda,
212                entry_point: "gemm_f32".to_string(),
213                workgroup_size: (16, 16, 1),
214                shared_memory: 4096,
215                uses_tensor_cores: false,
216            },
217        );
218
219        // GEMM kernels for f64
220        self.register_kernel(
221            KernelId::new("core", "gemm", "f64"),
222            KernelSource {
223                source: include_str!("gpu/kernels/gemm_f64.cu").to_string(),
224                backend: GpuBackend::Cuda,
225                entry_point: "gemm_f64".to_string(),
226                workgroup_size: (16, 16, 1),
227                shared_memory: 8192,
228                uses_tensor_cores: false,
229            },
230        );
231
232        // AXPY kernels
233        self.register_kernel(
234            KernelId::new("core", "axpy", "f32"),
235            KernelSource {
236                source: include_str!("gpu/kernels/axpy.cu").to_string(),
237                backend: GpuBackend::Cuda,
238                entry_point: "axpy_f32".to_string(),
239                workgroup_size: (256, 1, 1),
240                shared_memory: 0,
241                uses_tensor_cores: false,
242            },
243        );
244    }
245
246    fn register_reduction_kernels(&mut self) {
247        // Sum reduction kernels
248        self.register_kernel(
249            KernelId::new("core", "reduce_sum", "f32"),
250            KernelSource {
251                source: include_str!("gpu/kernels/reduce_sum.cu").to_string(),
252                backend: GpuBackend::Cuda,
253                entry_point: "reduce_sum_f32".to_string(),
254                workgroup_size: (256, 1, 1),
255                shared_memory: 1024,
256                uses_tensor_cores: false,
257            },
258        );
259
260        // Max reduction kernels
261        self.register_kernel(
262            KernelId::new("core", "reduce_max", "f32"),
263            KernelSource {
264                source: include_str!("gpu/kernels/reduce_max.cu").to_string(),
265                backend: GpuBackend::Cuda,
266                entry_point: "reduce_max_f32".to_string(),
267                workgroup_size: (256, 1, 1),
268                shared_memory: 1024,
269                uses_tensor_cores: false,
270            },
271        );
272    }
273
274    fn register_utility_kernels(&mut self) {
275        // Memory copy kernels
276        self.register_kernel(
277            KernelId::new("core", "memcpy", "f32"),
278            KernelSource {
279                source: include_str!("gpu/kernels/memcpy.cu").to_string(),
280                backend: GpuBackend::Cuda,
281                entry_point: "memcpy_f32".to_string(),
282                workgroup_size: (256, 1, 1),
283                shared_memory: 0,
284                uses_tensor_cores: false,
285            },
286        );
287
288        // Fill kernels
289        self.register_kernel(
290            KernelId::new("core", "fill", "f32"),
291            KernelSource {
292                source: include_str!("gpu/kernels/fill.cu").to_string(),
293                backend: GpuBackend::Cuda,
294                entry_point: "fill_f32".to_string(),
295                workgroup_size: (256, 1, 1),
296                shared_memory: 0,
297                uses_tensor_cores: false,
298            },
299        );
300    }
301}
302
303/// Register a kernel from a module
304///
305/// This should be called during module initialization to register
306/// module-specific kernels with the global registry.
307///
308/// # Example
309///
310/// ```rust
311/// use scirs2_core::gpu_registry::{register_module_kernel, KernelId, KernelSource};
312/// use scirs2_core::gpu::GpuBackend;
313///
314/// fn register_fft_kernels() {
315///     let fft_kernel_source = r#"
316///         extern "C" __global__ void fft2d_c32(float2* data, int n) {
317///             // FFT kernel implementation
318///         }
319///     "#;
320///
321///     register_module_kernel(
322///         KernelId::new("fft", "fft2d", "c32"),
323///         KernelSource {
324///             source: fft_kernel_source.to_string(),
325///             backend: GpuBackend::Cuda,
326///             entry_point: "fft2d_c32".to_string(),
327///             workgroup_size: (32, 8, 1),
328///             shared_memory: 16384,
329///             uses_tensor_cores: false,
330///         },
331///     );
332/// }
333/// ```
334#[allow(dead_code)]
335pub fn register_module_kernel(id: KernelId, source: KernelSource) {
336    let registry = KernelRegistry::global();
337    let mut registry = registry.lock().expect("Operation failed");
338    registry.register_kernel(id, source);
339}
340
/// Get a compiled kernel for the current device
///
/// Convenience wrapper around [`KernelRegistry::get_kernel`] on the global
/// registry.
///
/// # Errors
///
/// Propagates any error from [`KernelRegistry::get_kernel`] (kernel not
/// found, backend not supported, or compilation failure).
///
/// # Panics
///
/// Panics if the global registry mutex was poisoned by a panic in another
/// thread.
#[cfg(feature = "gpu")]
#[allow(dead_code)] // public convenience API; callers live in other crates/modules
pub fn get_kernel(id: &KernelId, device: &GpuDevice) -> Result<Arc<GpuKernel>, GpuError> {
    KernelRegistry::global()
        .lock()
        .expect("kernel registry mutex poisoned")
        .get_kernel(id, device)
}
349
350/// Check if a kernel is registered
351#[allow(dead_code)]
352pub fn has_kernel(id: &KernelId) -> bool {
353    let registry = KernelRegistry::global();
354    let registry = registry.lock().expect("Operation failed");
355    registry.has_kernel(id)
356}
357
358/// List all registered kernels
359#[allow(dead_code)]
360pub fn list_kernels() -> Vec<KernelId> {
361    let registry = KernelRegistry::global();
362    let registry = registry.lock().expect("Operation failed");
363    registry.list_kernels()
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests share the process-wide global registry, so
    // registrations made by one test are visible to the others. The
    // assertions below only check for presence, which keeps them
    // order-independent.

    #[test]
    fn test_kernel_id() {
        // Without a variant the name is "<module>_<operation>__<dtype>";
        // with a variant the double underscore precedes the variant instead.
        let id = KernelId::new("linalg", "gemm", "f32");
        assert_eq!(id.as_kernel_name(), "linalg_gemm__f32");

        let id_with_variant = KernelId::with_variant("fft", "fft2d", "c64", "batched");
        assert_eq!(id_with_variant.as_kernel_name(), "fft_fft2d_c64__batched");
    }

    #[test]
    fn test_kernel_registration() {
        let id = KernelId::new("test", "dummy", "f32");
        // NOTE(review): `GpuBackend` is imported under
        // `#[cfg(feature = "gpu")]` — confirm this test builds when the
        // "gpu" feature is disabled.
        let source = KernelSource {
            source: "dummy kernel".to_string(),
            backend: GpuBackend::Cuda,
            entry_point: "dummy".to_string(),
            workgroup_size: (1, 1, 1),
            shared_memory: 0,
            uses_tensor_cores: false,
        };

        register_module_kernel(id.clone(), source);
        assert!(has_kernel(&id));
    }

    #[test]
    fn test_builtin_kernels() {
        // Check that some built-in kernels are registered
        assert!(has_kernel(&KernelId::new("core", "gemm", "f32")));
        assert!(has_kernel(&KernelId::new("core", "reduce_sum", "f32")));
        assert!(has_kernel(&KernelId::new("core", "fill", "f32")));
    }
}