// scirs2_core/gpu/kernels/mod.rs

1//! GPU kernel library for common scientific computing operations
2//!
3//! This module provides optimized GPU kernels for various operations used in
4//! scientific computing, with support for multiple GPU backends.
5
6use std::collections::HashMap;
7use std::fmt;
8
9pub mod blas;
10pub mod complex;
11pub mod elementwise;
12pub mod ml;
13pub mod reduction;
14pub mod transform;
15
16use crate::gpu::{GpuBackend, GpuError};
17
/// Supported data types for GPU kernels.
///
/// The `Display` impl below renders each variant as its conventional
/// short name ("f32", "bf16", ...), suitable for embedding in kernel
/// source or cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point (f32)
    Float32,
    /// 64-bit floating point (f64)
    Float64,
    /// 32-bit signed integer (i32)
    Int32,
    /// 32-bit unsigned integer (u32)
    UInt32,
    /// 16-bit floating point (f16)
    Float16,
    /// Brain floating point (bfloat16)
    BFloat16,
}
34
35impl fmt::Display for DataType {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match self {
38            DataType::Float32 => write!(f, "f32"),
39            DataType::Float64 => write!(f, "f64"),
40            DataType::Int32 => write!(f, "i32"),
41            DataType::UInt32 => write!(f, "u32"),
42            DataType::Float16 => write!(f, "f16"),
43            DataType::BFloat16 => write!(f, "bf16"),
44        }
45    }
46}
47
/// The type of operation performed by the kernel.
///
/// Stored in [`KernelMetadata`] to describe a kernel's dominant resource
/// usage (e.g. for choosing launch configurations).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperationType {
    /// Primarily compute-intensive operations (ALU/FPU bound)
    ComputeIntensive,
    /// Primarily memory-intensive operations (bandwidth bound)
    MemoryIntensive,
    /// Balanced between compute and memory
    Balanced,
}
58
/// Metadata for kernel execution.
///
/// Describes launch and resource requirements a backend needs to schedule
/// the kernel. See the `Default` impl for the baseline values.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Recommended workgroup size as [x, y, z]
    pub workgroup_size: [u32; 3],
    /// Local (shared/workgroup) memory usage in bytes
    pub local_memory_usage: usize,
    /// Whether the kernel supports tensor cores (NVIDIA) or similar
    pub supports_tensor_cores: bool,
    /// Operation type (compute intensive, memory intensive, balanced)
    // NOTE(review): conventional Rust naming would be `operation_type`;
    // kept as-is because this is a public field callers may rely on.
    pub operationtype: OperationType,
    /// Additional backend-specific metadata as free-form key/value pairs
    pub backend_metadata: HashMap<String, String>,
}
73
74impl Default for KernelMetadata {
75    fn default() -> Self {
76        Self {
77            workgroup_size: [16, 16, 1],
78            local_memory_usage: 0,
79            supports_tensor_cores: false,
80            operationtype: OperationType::Balanced,
81            backend_metadata: HashMap::new(),
82        }
83    }
84}
85
/// Parameters for kernel specialization.
///
/// Passed to [`GpuKernel::can_specialize`] / [`GpuKernel::specialize`];
/// construct with [`KernelParams::new`] and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Numeric type (f32, f64, etc.)
    pub datatype: DataType,
    /// Input dimensions (empty until set via `with_input_dims`)
    pub input_dims: Vec<usize>,
    /// Output dimensions (empty until set via `with_output_dims`)
    pub output_dims: Vec<usize>,
    /// Additional numeric parameters, keyed by name
    pub numeric_params: HashMap<String, f64>,
    /// Additional string parameters, keyed by name
    pub string_params: HashMap<String, String>,
}
100
101impl KernelParams {
102    /// Create new kernel parameters
103    pub fn new(datatype: DataType) -> Self {
104        Self {
105            datatype,
106            input_dims: Vec::new(),
107            output_dims: Vec::new(),
108            numeric_params: HashMap::new(),
109            string_params: HashMap::new(),
110        }
111    }
112
113    /// Set input dimensions
114    pub fn with_input_dims(mut self, dims: Vec<usize>) -> Self {
115        self.input_dims = dims;
116        self
117    }
118
119    /// Set output dimensions
120    pub fn with_output_dims(mut self, dims: Vec<usize>) -> Self {
121        self.output_dims = dims;
122        self
123    }
124
125    /// Add a numeric parameter
126    pub fn with_numeric_param(mut self, name: &str, value: f64) -> Self {
127        self.numeric_params.insert(name.to_string(), value);
128        self
129    }
130
131    /// Add a string parameter
132    pub fn with_string_param(mut self, name: &str, value: &str) -> Self {
133        self.string_params
134            .insert(name.to_string(), value.to_string());
135        self
136    }
137}
138
/// GPU Kernel interface.
///
/// Implementors provide per-backend source code plus execution metadata,
/// and may optionally support specialization for concrete parameters.
/// `Send + Sync` so kernels can live in a shared [`KernelRegistry`].
pub trait GpuKernel: Send + Sync {
    /// The name of the kernel (used as the registry key)
    fn name(&self) -> &str;

    /// Get kernel source for the specified backend, or an error if the
    /// backend is unsupported
    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError>;

    /// Get kernel metadata (workgroup size, memory requirements, etc.)
    fn metadata(&self) -> KernelMetadata;

    /// Can this kernel be specialized for the given parameters?
    /// Callers should check this before calling [`GpuKernel::specialize`].
    fn can_specialize(&self, params: &KernelParams) -> bool;

    /// Create a specialized version of this kernel for the given parameters
    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError>;
}
156
/// Base kernel implementation that can be used by specialized kernels.
///
/// Holds one pre-rendered source string per supported backend (empty when a
/// backend has no implementation) plus shared execution metadata. It does
/// not support specialization (see its [`GpuKernel`] impl).
pub struct BaseKernel {
    // Registry key returned by `GpuKernel::name`.
    name: String,
    // Per-backend kernel sources; "" means "not implemented for this backend".
    cuda_source: String,
    rocm_source: String,
    wgpu_source: String,
    metal_source: String,
    opencl_source: String,
    // Launch/resource metadata returned (cloned) by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
167
168impl BaseKernel {
169    /// Create a new base kernel
170    pub fn new(
171        name: &str,
172        cuda_source: &str,
173        rocm_source: &str,
174        wgpu_source: &str,
175        metal_source: &str,
176        opencl_source: &str,
177        metadata: KernelMetadata,
178    ) -> Self {
179        Self {
180            name: name.to_string(),
181            cuda_source: cuda_source.to_string(),
182            rocm_source: rocm_source.to_string(),
183            wgpu_source: wgpu_source.to_string(),
184            metal_source: metal_source.to_string(),
185            opencl_source: opencl_source.to_string(),
186            metadata,
187        }
188    }
189}
190
191impl GpuKernel for BaseKernel {
192    fn name(&self) -> &str {
193        &self.name
194    }
195
196    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
197        match backend {
198            GpuBackend::Cuda => Ok(self.cuda_source.clone()),
199            GpuBackend::Rocm => Ok(self.rocm_source.clone()),
200            GpuBackend::Wgpu => Ok(self.wgpu_source.clone()),
201            GpuBackend::Metal => Ok(self.metal_source.clone()),
202            GpuBackend::OpenCL => Ok(self.opencl_source.clone()),
203            _ => Err(GpuError::UnsupportedBackend(backend)),
204        }
205    }
206
207    fn metadata(&self) -> KernelMetadata {
208        self.metadata.clone()
209    }
210
211    fn can_specialize(&self, params: &KernelParams) -> bool {
212        false // Base implementation doesn't support specialization
213    }
214
215    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
216        Err(GpuError::SpecializationNotSupported)
217    }
218}
219
/// Registry of available GPU kernels, keyed by kernel name
/// (as reported by [`GpuKernel::name`]).
pub struct KernelRegistry {
    // Registering a kernel with a duplicate name replaces the old entry.
    kernels: HashMap<String, Box<dyn GpuKernel>>,
}
224
225impl KernelRegistry {
226    /// Create a new kernel registry
227    pub fn new() -> Self {
228        Self {
229            kernels: HashMap::new(),
230        }
231    }
232
233    /// Create a registry with all default kernels
234    pub fn with_default_kernels() -> Self {
235        let mut registry = Self::new();
236
237        // Register BLAS kernels
238        registry.register(Box::new(blas::gemm::GemmKernel::new()));
239        registry.register(Box::new(blas::axpy::AxpyKernel::new()));
240        registry.register(Box::new(blas::gemv::GemvKernel::new()));
241
242        // Register elementwise kernels
243        registry.register(Box::new(elementwise::ElementwiseAddKernel::new()));
244        registry.register(Box::new(elementwise::ElementwiseSubKernel::new()));
245        registry.register(Box::new(elementwise::ElementwiseMulKernel::new()));
246        registry.register(Box::new(elementwise::ElementwiseDivKernel::new()));
247        registry.register(Box::new(elementwise::ElementwisePowKernel::new()));
248        registry.register(Box::new(elementwise::ElementwiseSqrtKernel::new()));
249        registry.register(Box::new(elementwise::ElementwiseExpKernel::new()));
250        registry.register(Box::new(elementwise::ElementwiseLogKernel::new()));
251
252        // Register optimization kernels
253        registry.register(Box::new(create_adam_optimizer_kernel()));
254        registry.register(Box::new(create_sgd_optimizer_kernel()));
255        registry.register(Box::new(create_rmsprop_optimizer_kernel()));
256        registry.register(Box::new(create_adagrad_optimizer_kernel()));
257        registry.register(Box::new(create_lamb_optimizer_kernel()));
258
259        // Register utility kernels
260        registry.register(Box::new(create_memcpy_kernel()));
261        registry.register(Box::new(create_fill_kernel()));
262        registry.register(Box::new(create_reduce_sum_kernel()));
263        registry.register(Box::new(create_reduce_max_kernel()));
264
265        // Register transform kernels
266        registry.register(Box::new(transform::fft::FftKernel::new()));
267        registry.register(Box::new(transform::convolution::Conv1dKernel::new()));
268        registry.register(Box::new(transform::convolution::Conv2dKernel::new()));
269
270        // Register reduction kernels
271        registry.register(Box::new(reduction::sum::SumKernel::new()));
272        registry.register(Box::new(reduction::norm::NormKernel::new()));
273        registry.register(Box::new(reduction::min_max::MinKernel::new()));
274        registry.register(Box::new(reduction::min_max::MaxKernel::new()));
275        registry.register(Box::new(reduction::mean::MeanKernel::new()));
276        registry.register(Box::new(reduction::std_dev::StdDevKernel::new()));
277
278        // Register ML kernels
279        registry.register(Box::new(ml::activation::ReluKernel::new()));
280        registry.register(Box::new(ml::activation::SigmoidKernel::new()));
281        registry.register(Box::new(ml::activation::TanhKernel::new()));
282        registry.register(Box::new(ml::softmax::SoftmaxKernel::new()));
283        registry.register(Box::new(ml::pooling::MaxPoolKernel::new()));
284        registry.register(Box::new(ml::pooling::AvgPoolKernel::new()));
285
286        // Register complex number kernels
287        registry.register(Box::new(complex::ComplexMultiplyKernel::new()));
288        registry.register(Box::new(complex::ComplexConjugateKernel::new()));
289        registry.register(Box::new(complex::ComplexMatMulKernel::new()));
290
291        // Register RK4 integration kernels for advanced mode
292        registry.register(Box::new(create_rk4_stage1_kernel()));
293        registry.register(Box::new(create_rk4_stage2_kernel()));
294        registry.register(Box::new(create_rk4_stage3_kernel()));
295        registry.register(Box::new(create_rk4_stage4_kernel()));
296        registry.register(Box::new(create_rk4_combine_kernel()));
297        registry.register(Box::new(createerror_estimate_kernel()));
298
299        registry
300    }
301
302    /// Register a kernel
303    pub fn register(&mut self, kernel: Box<dyn GpuKernel>) {
304        self.kernels.insert(kernel.name().to_string(), kernel);
305    }
306
307    /// Get a kernel by name
308    pub fn get(&self, name: &str) -> Option<&dyn GpuKernel> {
309        self.kernels.get(name).map(|k| k.as_ref())
310    }
311
312    /// Get a specialized kernel
313    pub fn get_specialized(
314        &self,
315        name: &str,
316        params: &KernelParams,
317    ) -> Result<Box<dyn GpuKernel>, GpuError> {
318        let kernel = self
319            .get(name)
320            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
321
322        if kernel.can_specialize(params) {
323            kernel.specialize(params)
324        } else {
325            Err(GpuError::SpecializationNotSupported)
326        }
327    }
328}
329
330impl Default for KernelRegistry {
331    fn default() -> Self {
332        Self::with_default_kernels()
333    }
334}
335
336/// Create RK4 Stage 1 kernel for advanced mode GPU acceleration
337#[allow(dead_code)]
338fn create_rk4_stage1_kernel() -> BaseKernel {
339    let cuda_source = include_str!("rk4_stage1.cu");
340    let metadata = KernelMetadata {
341        workgroup_size: [256, 1, 1],
342        local_memory_usage: 0,
343        supports_tensor_cores: false,
344        operationtype: OperationType::ComputeIntensive,
345        backend_metadata: HashMap::new(),
346    };
347
348    BaseKernel::new(
349        "rk4_stage1",
350        cuda_source,
351        cuda_source, // Use CUDA source for ROCm (HIP compatible)
352        "",          // WGPU source not implemented yet
353        "",          // Metal source not implemented yet
354        cuda_source, // Use CUDA source for OpenCL (with minor modifications)
355        metadata,
356    )
357}
358
359/// Create RK4 Stage 2 kernel for advanced mode GPU acceleration
360#[allow(dead_code)]
361fn create_rk4_stage2_kernel() -> BaseKernel {
362    let cuda_source = include_str!("rk4_stage2.cu");
363    let metadata = KernelMetadata {
364        workgroup_size: [256, 1, 1],
365        local_memory_usage: 0,
366        supports_tensor_cores: false,
367        operationtype: OperationType::ComputeIntensive,
368        backend_metadata: HashMap::new(),
369    };
370
371    BaseKernel::new(
372        "rk4_stage2",
373        cuda_source,
374        cuda_source,
375        "",
376        "",
377        cuda_source,
378        metadata,
379    )
380}
381
382/// Create RK4 Stage 3 kernel for advanced mode GPU acceleration
383#[allow(dead_code)]
384fn create_rk4_stage3_kernel() -> BaseKernel {
385    let cuda_source = include_str!("rk4_stage3.cu");
386    let metadata = KernelMetadata {
387        workgroup_size: [256, 1, 1],
388        local_memory_usage: 0,
389        supports_tensor_cores: false,
390        operationtype: OperationType::ComputeIntensive,
391        backend_metadata: HashMap::new(),
392    };
393
394    BaseKernel::new(
395        "rk4_stage3",
396        cuda_source,
397        cuda_source,
398        "",
399        "",
400        cuda_source,
401        metadata,
402    )
403}
404
405/// Create RK4 Stage 4 kernel for advanced mode GPU acceleration
406#[allow(dead_code)]
407fn create_rk4_stage4_kernel() -> BaseKernel {
408    let cuda_source = include_str!("rk4_stage4.cu");
409    let metadata = KernelMetadata {
410        workgroup_size: [256, 1, 1],
411        local_memory_usage: 0,
412        supports_tensor_cores: false,
413        operationtype: OperationType::ComputeIntensive,
414        backend_metadata: HashMap::new(),
415    };
416
417    BaseKernel::new(
418        "rk4_stage4",
419        cuda_source,
420        cuda_source,
421        "",
422        "",
423        cuda_source,
424        metadata,
425    )
426}
427
428/// Create RK4 Combination kernel for advanced mode GPU acceleration
429#[allow(dead_code)]
430fn create_rk4_combine_kernel() -> BaseKernel {
431    let cuda_source = include_str!("rk4_combine.cu");
432    let metadata = KernelMetadata {
433        workgroup_size: [256, 1, 1],
434        local_memory_usage: 0,
435        supports_tensor_cores: false,
436        operationtype: OperationType::MemoryIntensive,
437        backend_metadata: HashMap::new(),
438    };
439
440    BaseKernel::new(
441        "rk4_combine",
442        cuda_source,
443        cuda_source,
444        "",
445        "",
446        cuda_source,
447        metadata,
448    )
449}
450
451/// Create Error Estimation kernel for adaptive step size control
452#[allow(dead_code)]
453fn createerror_estimate_kernel() -> BaseKernel {
454    let cuda_source = include_str!("error_estimate.cu");
455    let metadata = KernelMetadata {
456        workgroup_size: [256, 1, 1],
457        local_memory_usage: 1024, // Shared memory for reduction
458        supports_tensor_cores: false,
459        operationtype: OperationType::ComputeIntensive,
460        backend_metadata: HashMap::new(),
461    };
462
463    BaseKernel::new(
464        "error_estimate",
465        cuda_source,
466        cuda_source,
467        "",
468        "",
469        cuda_source,
470        metadata,
471    )
472}
473
474/// Create Adam optimizer kernel for GPU acceleration
475#[allow(dead_code)]
476fn create_adam_optimizer_kernel() -> BaseKernel {
477    let cuda_source = include_str!("adam_optimizer.cu");
478    let metadata = KernelMetadata {
479        workgroup_size: [256, 1, 1],
480        local_memory_usage: 0,
481        supports_tensor_cores: false,
482        operationtype: OperationType::ComputeIntensive,
483        backend_metadata: HashMap::new(),
484    };
485
486    BaseKernel::new(
487        "adam_optimizer",
488        cuda_source,
489        cuda_source,
490        "",
491        "",
492        cuda_source,
493        metadata,
494    )
495}
496
497/// Create SGD optimizer kernel for GPU acceleration
498#[allow(dead_code)]
499fn create_sgd_optimizer_kernel() -> BaseKernel {
500    let cuda_source = include_str!("sgd_optimizer.cu");
501    let metadata = KernelMetadata {
502        workgroup_size: [256, 1, 1],
503        local_memory_usage: 0,
504        supports_tensor_cores: false,
505        operationtype: OperationType::MemoryIntensive,
506        backend_metadata: HashMap::new(),
507    };
508
509    BaseKernel::new(
510        "sgd_optimizer",
511        cuda_source,
512        cuda_source,
513        "",
514        "",
515        cuda_source,
516        metadata,
517    )
518}
519
520/// Create RMSprop optimizer kernel for GPU acceleration
521#[allow(dead_code)]
522fn create_rmsprop_optimizer_kernel() -> BaseKernel {
523    let cuda_source = include_str!("rmsprop_optimizer.cu");
524    let metadata = KernelMetadata {
525        workgroup_size: [256, 1, 1],
526        local_memory_usage: 0,
527        supports_tensor_cores: false,
528        operationtype: OperationType::ComputeIntensive,
529        backend_metadata: HashMap::new(),
530    };
531
532    BaseKernel::new(
533        "rmsprop_optimizer",
534        cuda_source,
535        cuda_source,
536        "",
537        "",
538        cuda_source,
539        metadata,
540    )
541}
542
543/// Create Adagrad optimizer kernel for GPU acceleration
544#[allow(dead_code)]
545fn create_adagrad_optimizer_kernel() -> BaseKernel {
546    let cuda_source = include_str!("adagrad_optimizer.cu");
547    let metadata = KernelMetadata {
548        workgroup_size: [256, 1, 1],
549        local_memory_usage: 0,
550        supports_tensor_cores: false,
551        operationtype: OperationType::ComputeIntensive,
552        backend_metadata: HashMap::new(),
553    };
554
555    BaseKernel::new(
556        "adagrad_optimizer",
557        cuda_source,
558        cuda_source,
559        "",
560        "",
561        cuda_source,
562        metadata,
563    )
564}
565
566/// Create LAMB optimizer kernel for GPU acceleration
567#[allow(dead_code)]
568fn create_lamb_optimizer_kernel() -> BaseKernel {
569    let cuda_source = include_str!("lamb_optimizer.cu");
570    let metadata = KernelMetadata {
571        workgroup_size: [256, 1, 1],
572        local_memory_usage: 0,
573        supports_tensor_cores: false,
574        operationtype: OperationType::ComputeIntensive,
575        backend_metadata: HashMap::new(),
576    };
577
578    BaseKernel::new(
579        "lamb_optimizer",
580        cuda_source,
581        cuda_source,
582        "",
583        "",
584        cuda_source,
585        metadata,
586    )
587}
588
589/// Create memory copy kernel for GPU acceleration
590#[allow(dead_code)]
591fn create_memcpy_kernel() -> BaseKernel {
592    let cuda_source = include_str!("memcpy.cu");
593    let metadata = KernelMetadata {
594        workgroup_size: [256, 1, 1],
595        local_memory_usage: 0,
596        supports_tensor_cores: false,
597        operationtype: OperationType::MemoryIntensive,
598        backend_metadata: HashMap::new(),
599    };
600
601    BaseKernel::new(
602        "memcpy",
603        cuda_source,
604        cuda_source,
605        "",
606        "",
607        cuda_source,
608        metadata,
609    )
610}
611
612/// Create fill kernel for GPU acceleration
613#[allow(dead_code)]
614fn create_fill_kernel() -> BaseKernel {
615    let cuda_source = include_str!("fill.cu");
616    let metadata = KernelMetadata {
617        workgroup_size: [256, 1, 1],
618        local_memory_usage: 0,
619        supports_tensor_cores: false,
620        operationtype: OperationType::MemoryIntensive,
621        backend_metadata: HashMap::new(),
622    };
623
624    BaseKernel::new(
625        "fill",
626        cuda_source,
627        cuda_source,
628        "",
629        "",
630        cuda_source,
631        metadata,
632    )
633}
634
635/// Create reduce sum kernel for GPU acceleration
636#[allow(dead_code)]
637fn create_reduce_sum_kernel() -> BaseKernel {
638    let cuda_source = include_str!("reduce_sum.cu");
639    let metadata = KernelMetadata {
640        workgroup_size: [256, 1, 1],
641        local_memory_usage: 1024, // Shared memory for reduction
642        supports_tensor_cores: false,
643        operationtype: OperationType::ComputeIntensive,
644        backend_metadata: HashMap::new(),
645    };
646
647    BaseKernel::new(
648        "reduce_sum",
649        cuda_source,
650        cuda_source,
651        "",
652        "",
653        cuda_source,
654        metadata,
655    )
656}
657
658/// Create reduce max kernel for GPU acceleration
659#[allow(dead_code)]
660fn create_reduce_max_kernel() -> BaseKernel {
661    let cuda_source = include_str!("reduce_max.cu");
662    let metadata = KernelMetadata {
663        workgroup_size: [256, 1, 1],
664        local_memory_usage: 1024, // Shared memory for reduction
665        supports_tensor_cores: false,
666        operationtype: OperationType::ComputeIntensive,
667        backend_metadata: HashMap::new(),
668    };
669
670    BaseKernel::new(
671        "reduce_max",
672        cuda_source,
673        cuda_source,
674        "",
675        "",
676        cuda_source,
677        metadata,
678    )
679}