1use std::collections::HashMap;
7use std::fmt;
8
9pub mod blas;
10pub mod complex;
11pub mod elementwise;
12pub mod ml;
13pub mod reduction;
14pub mod transform;
15
16use crate::gpu::{GpuBackend, GpuError};
17
/// Element data types a GPU kernel can operate on.
///
/// The `Display` impl below renders each variant as its short lowercase
/// mnemonic (e.g. `"f32"`, `"bf16"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit IEEE-754 floating point.
    Float32,
    /// 64-bit IEEE-754 floating point.
    Float64,
    /// 32-bit signed integer.
    Int32,
    /// 32-bit unsigned integer.
    UInt32,
    /// 16-bit IEEE-754 half-precision floating point.
    Float16,
    /// 16-bit brain floating point.
    BFloat16,
}
34
35impl fmt::Display for DataType {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 match self {
38 DataType::Float32 => write!(f, "f32"),
39 DataType::Float64 => write!(f, "f64"),
40 DataType::Int32 => write!(f, "i32"),
41 DataType::UInt32 => write!(f, "u32"),
42 DataType::Float16 => write!(f, "f16"),
43 DataType::BFloat16 => write!(f, "bf16"),
44 }
45 }
46}
47
/// Coarse classification of where a kernel's runtime cost is dominated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperationType {
    /// Dominated by arithmetic / compute work.
    ComputeIntensive,
    /// Dominated by memory traffic.
    MemoryIntensive,
    /// Mixed compute and memory cost.
    Balanced,
}
58
/// Static launch/tuning metadata attached to a kernel.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Workgroup (thread-block) dimensions as `[x, y, z]`.
    pub workgroup_size: [u32; 3],
    /// Local (workgroup-shared) memory the kernel uses — presumably in
    /// bytes given values like 1024 elsewhere in this file; TODO confirm
    /// the unit at the launch site.
    pub local_memory_usage: usize,
    /// Whether the kernel can take advantage of tensor cores.
    pub supports_tensor_cores: bool,
    /// Coarse compute/memory profile of the kernel.
    // NOTE(review): field name lacks an underscore (`operation_type`);
    // it is public API, so it is kept as-is for compatibility.
    pub operationtype: OperationType,
    /// Free-form backend-specific key/value metadata.
    pub backend_metadata: HashMap<String, String>,
}
73
74impl Default for KernelMetadata {
75 fn default() -> Self {
76 Self {
77 workgroup_size: [16, 16, 1],
78 local_memory_usage: 0,
79 supports_tensor_cores: false,
80 operationtype: OperationType::Balanced,
81 backend_metadata: HashMap::new(),
82 }
83 }
84}
85
/// Concrete parameters a kernel may be specialized against.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Element type the kernel should operate on.
    // NOTE(review): field name lacks an underscore (`data_type`); it is
    // public API, so it is kept as-is for compatibility.
    pub datatype: DataType,
    /// Dimensions of the input buffer(s).
    pub input_dims: Vec<usize>,
    /// Dimensions of the output buffer(s).
    pub output_dims: Vec<usize>,
    /// Named numeric parameters (e.g. tuning constants).
    pub numeric_params: HashMap<String, f64>,
    /// Named string parameters.
    pub string_params: HashMap<String, String>,
}
100
101impl KernelParams {
102 pub fn new(datatype: DataType) -> Self {
104 Self {
105 datatype,
106 input_dims: Vec::new(),
107 output_dims: Vec::new(),
108 numeric_params: HashMap::new(),
109 string_params: HashMap::new(),
110 }
111 }
112
113 pub fn with_input_dims(mut self, dims: Vec<usize>) -> Self {
115 self.input_dims = dims;
116 self
117 }
118
119 pub fn with_output_dims(mut self, dims: Vec<usize>) -> Self {
121 self.output_dims = dims;
122 self
123 }
124
125 pub fn with_numeric_param(mut self, name: &str, value: f64) -> Self {
127 self.numeric_params.insert(name.to_string(), value);
128 self
129 }
130
131 pub fn with_string_param(mut self, name: &str, value: &str) -> Self {
133 self.string_params
134 .insert(name.to_string(), value.to_string());
135 self
136 }
137}
138
/// A GPU compute kernel that can emit source code for multiple backends.
pub trait GpuKernel: Send + Sync {
    /// Unique kernel name; used as the registry key.
    fn name(&self) -> &str;

    /// Returns the kernel source for `backend`, or an error if this kernel
    /// has no implementation for that backend.
    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError>;

    /// Launch/tuning metadata for this kernel.
    fn metadata(&self) -> KernelMetadata;

    /// Whether [`GpuKernel::specialize`] can produce a variant tuned for
    /// `params`.
    fn can_specialize(&self, params: &KernelParams) -> bool;

    /// Produces a kernel specialized for `params`; implementations that do
    /// not support specialization return an error.
    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError>;
}
156
/// A simple [`GpuKernel`] backed by fixed, pre-written source strings —
/// one per supported backend. Does not support specialization.
pub struct BaseKernel {
    // Registry key returned by `name()`.
    name: String,
    cuda_source: String,
    rocm_source: String,
    wgpu_source: String,
    metal_source: String,
    opencl_source: String,
    // Static launch metadata returned (cloned) by `metadata()`.
    metadata: KernelMetadata,
}
167
168impl BaseKernel {
169 pub fn new(
171 name: &str,
172 cuda_source: &str,
173 rocm_source: &str,
174 wgpu_source: &str,
175 metal_source: &str,
176 opencl_source: &str,
177 metadata: KernelMetadata,
178 ) -> Self {
179 Self {
180 name: name.to_string(),
181 cuda_source: cuda_source.to_string(),
182 rocm_source: rocm_source.to_string(),
183 wgpu_source: wgpu_source.to_string(),
184 metal_source: metal_source.to_string(),
185 opencl_source: opencl_source.to_string(),
186 metadata,
187 }
188 }
189}
190
191impl GpuKernel for BaseKernel {
192 fn name(&self) -> &str {
193 &self.name
194 }
195
196 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
197 match backend {
198 GpuBackend::Cuda => Ok(self.cuda_source.clone()),
199 GpuBackend::Rocm => Ok(self.rocm_source.clone()),
200 GpuBackend::Wgpu => Ok(self.wgpu_source.clone()),
201 GpuBackend::Metal => Ok(self.metal_source.clone()),
202 GpuBackend::OpenCL => Ok(self.opencl_source.clone()),
203 _ => Err(GpuError::UnsupportedBackend(backend)),
204 }
205 }
206
207 fn metadata(&self) -> KernelMetadata {
208 self.metadata.clone()
209 }
210
211 fn can_specialize(&self, params: &KernelParams) -> bool {
212 false }
214
215 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
216 Err(GpuError::SpecializationNotSupported)
217 }
218}
219
/// Name-indexed store of the available GPU kernels.
pub struct KernelRegistry {
    // Keyed by each kernel's `name()`.
    kernels: HashMap<String, Box<dyn GpuKernel>>,
}
224
impl KernelRegistry {
    /// Creates an empty registry with no kernels registered.
    pub fn new() -> Self {
        Self {
            kernels: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with the full built-in kernel set:
    /// BLAS, elementwise, optimizers, memory utilities, transforms,
    /// reductions, ML primitives, complex arithmetic, and RK4 ODE stages.
    pub fn with_default_kernels() -> Self {
        let mut registry = Self::new();

        // BLAS kernels.
        registry.register(Box::new(blas::gemm::GemmKernel::new()));
        registry.register(Box::new(blas::axpy::AxpyKernel::new()));
        registry.register(Box::new(blas::gemv::GemvKernel::new()));

        // Elementwise arithmetic and math functions.
        registry.register(Box::new(elementwise::ElementwiseAddKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseSubKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseMulKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseDivKernel::new()));
        registry.register(Box::new(elementwise::ElementwisePowKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseSqrtKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseExpKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseLogKernel::new()));

        // Optimizer update kernels.
        registry.register(Box::new(create_adam_optimizer_kernel()));
        registry.register(Box::new(create_sgd_optimizer_kernel()));
        registry.register(Box::new(create_rmsprop_optimizer_kernel()));
        registry.register(Box::new(create_adagrad_optimizer_kernel()));
        registry.register(Box::new(create_lamb_optimizer_kernel()));

        // Memory/utility kernels.
        registry.register(Box::new(create_memcpy_kernel()));
        registry.register(Box::new(create_fill_kernel()));
        registry.register(Box::new(create_reduce_sum_kernel()));
        registry.register(Box::new(create_reduce_max_kernel()));

        // Transform kernels.
        registry.register(Box::new(transform::fft::FftKernel::new()));
        registry.register(Box::new(transform::convolution::Conv1dKernel::new()));
        registry.register(Box::new(transform::convolution::Conv2dKernel::new()));

        // Reduction kernels.
        registry.register(Box::new(reduction::sum::SumKernel::new()));
        registry.register(Box::new(reduction::norm::NormKernel::new()));
        registry.register(Box::new(reduction::min_max::MinKernel::new()));
        registry.register(Box::new(reduction::min_max::MaxKernel::new()));
        registry.register(Box::new(reduction::mean::MeanKernel::new()));
        registry.register(Box::new(reduction::std_dev::StdDevKernel::new()));

        // ML primitive kernels.
        registry.register(Box::new(ml::activation::ReluKernel::new()));
        registry.register(Box::new(ml::activation::SigmoidKernel::new()));
        registry.register(Box::new(ml::activation::TanhKernel::new()));
        registry.register(Box::new(ml::softmax::SoftmaxKernel::new()));
        registry.register(Box::new(ml::pooling::MaxPoolKernel::new()));
        registry.register(Box::new(ml::pooling::AvgPoolKernel::new()));

        // Complex-number kernels.
        registry.register(Box::new(complex::ComplexMultiplyKernel::new()));
        registry.register(Box::new(complex::ComplexConjugateKernel::new()));
        registry.register(Box::new(complex::ComplexMatMulKernel::new()));

        // RK4 ODE integration stages plus error estimation.
        registry.register(Box::new(create_rk4_stage1_kernel()));
        registry.register(Box::new(create_rk4_stage2_kernel()));
        registry.register(Box::new(create_rk4_stage3_kernel()));
        registry.register(Box::new(create_rk4_stage4_kernel()));
        registry.register(Box::new(create_rk4_combine_kernel()));
        // NOTE(review): factory name is missing an underscore
        // (`create_error_estimate_kernel`); kept because the definition
        // below uses this exact name.
        registry.register(Box::new(createerror_estimate_kernel()));

        registry
    }

    /// Registers `kernel` under its `name()`. A kernel previously
    /// registered under the same name is silently replaced
    /// (`HashMap::insert` semantics).
    pub fn register(&mut self, kernel: Box<dyn GpuKernel>) {
        self.kernels.insert(kernel.name().to_string(), kernel);
    }

    /// Looks up a kernel by name.
    pub fn get(&self, name: &str) -> Option<&dyn GpuKernel> {
        self.kernels.get(name).map(|k| k.as_ref())
    }

    /// Looks up `name` and asks the kernel to specialize for `params`.
    ///
    /// # Errors
    /// - `GpuError::KernelNotFound` if no kernel is registered under `name`.
    /// - `GpuError::SpecializationNotSupported` if the kernel declines to
    ///   specialize, or whatever error `specialize` itself returns.
    pub fn get_specialized(
        &self,
        name: &str,
        params: &KernelParams,
    ) -> Result<Box<dyn GpuKernel>, GpuError> {
        let kernel = self
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        if kernel.can_specialize(params) {
            kernel.specialize(params)
        } else {
            Err(GpuError::SpecializationNotSupported)
        }
    }
}
329
330impl Default for KernelRegistry {
331 fn default() -> Self {
332 Self::with_default_kernels()
333 }
334}
335
336#[allow(dead_code)]
338fn create_rk4_stage1_kernel() -> BaseKernel {
339 let cuda_source = include_str!("rk4_stage1.cu");
340 let metadata = KernelMetadata {
341 workgroup_size: [256, 1, 1],
342 local_memory_usage: 0,
343 supports_tensor_cores: false,
344 operationtype: OperationType::ComputeIntensive,
345 backend_metadata: HashMap::new(),
346 };
347
348 BaseKernel::new(
349 "rk4_stage1",
350 cuda_source,
351 cuda_source, "", "", cuda_source, metadata,
356 )
357}
358
359#[allow(dead_code)]
361fn create_rk4_stage2_kernel() -> BaseKernel {
362 let cuda_source = include_str!("rk4_stage2.cu");
363 let metadata = KernelMetadata {
364 workgroup_size: [256, 1, 1],
365 local_memory_usage: 0,
366 supports_tensor_cores: false,
367 operationtype: OperationType::ComputeIntensive,
368 backend_metadata: HashMap::new(),
369 };
370
371 BaseKernel::new(
372 "rk4_stage2",
373 cuda_source,
374 cuda_source,
375 "",
376 "",
377 cuda_source,
378 metadata,
379 )
380}
381
382#[allow(dead_code)]
384fn create_rk4_stage3_kernel() -> BaseKernel {
385 let cuda_source = include_str!("rk4_stage3.cu");
386 let metadata = KernelMetadata {
387 workgroup_size: [256, 1, 1],
388 local_memory_usage: 0,
389 supports_tensor_cores: false,
390 operationtype: OperationType::ComputeIntensive,
391 backend_metadata: HashMap::new(),
392 };
393
394 BaseKernel::new(
395 "rk4_stage3",
396 cuda_source,
397 cuda_source,
398 "",
399 "",
400 cuda_source,
401 metadata,
402 )
403}
404
405#[allow(dead_code)]
407fn create_rk4_stage4_kernel() -> BaseKernel {
408 let cuda_source = include_str!("rk4_stage4.cu");
409 let metadata = KernelMetadata {
410 workgroup_size: [256, 1, 1],
411 local_memory_usage: 0,
412 supports_tensor_cores: false,
413 operationtype: OperationType::ComputeIntensive,
414 backend_metadata: HashMap::new(),
415 };
416
417 BaseKernel::new(
418 "rk4_stage4",
419 cuda_source,
420 cuda_source,
421 "",
422 "",
423 cuda_source,
424 metadata,
425 )
426}
427
428#[allow(dead_code)]
430fn create_rk4_combine_kernel() -> BaseKernel {
431 let cuda_source = include_str!("rk4_combine.cu");
432 let metadata = KernelMetadata {
433 workgroup_size: [256, 1, 1],
434 local_memory_usage: 0,
435 supports_tensor_cores: false,
436 operationtype: OperationType::MemoryIntensive,
437 backend_metadata: HashMap::new(),
438 };
439
440 BaseKernel::new(
441 "rk4_combine",
442 cuda_source,
443 cuda_source,
444 "",
445 "",
446 cuda_source,
447 metadata,
448 )
449}
450
451#[allow(dead_code)]
453fn createerror_estimate_kernel() -> BaseKernel {
454 let cuda_source = include_str!("error_estimate.cu");
455 let metadata = KernelMetadata {
456 workgroup_size: [256, 1, 1],
457 local_memory_usage: 1024, supports_tensor_cores: false,
459 operationtype: OperationType::ComputeIntensive,
460 backend_metadata: HashMap::new(),
461 };
462
463 BaseKernel::new(
464 "error_estimate",
465 cuda_source,
466 cuda_source,
467 "",
468 "",
469 cuda_source,
470 metadata,
471 )
472}
473
474#[allow(dead_code)]
476fn create_adam_optimizer_kernel() -> BaseKernel {
477 let cuda_source = include_str!("adam_optimizer.cu");
478 let metadata = KernelMetadata {
479 workgroup_size: [256, 1, 1],
480 local_memory_usage: 0,
481 supports_tensor_cores: false,
482 operationtype: OperationType::ComputeIntensive,
483 backend_metadata: HashMap::new(),
484 };
485
486 BaseKernel::new(
487 "adam_optimizer",
488 cuda_source,
489 cuda_source,
490 "",
491 "",
492 cuda_source,
493 metadata,
494 )
495}
496
497#[allow(dead_code)]
499fn create_sgd_optimizer_kernel() -> BaseKernel {
500 let cuda_source = include_str!("sgd_optimizer.cu");
501 let metadata = KernelMetadata {
502 workgroup_size: [256, 1, 1],
503 local_memory_usage: 0,
504 supports_tensor_cores: false,
505 operationtype: OperationType::MemoryIntensive,
506 backend_metadata: HashMap::new(),
507 };
508
509 BaseKernel::new(
510 "sgd_optimizer",
511 cuda_source,
512 cuda_source,
513 "",
514 "",
515 cuda_source,
516 metadata,
517 )
518}
519
520#[allow(dead_code)]
522fn create_rmsprop_optimizer_kernel() -> BaseKernel {
523 let cuda_source = include_str!("rmsprop_optimizer.cu");
524 let metadata = KernelMetadata {
525 workgroup_size: [256, 1, 1],
526 local_memory_usage: 0,
527 supports_tensor_cores: false,
528 operationtype: OperationType::ComputeIntensive,
529 backend_metadata: HashMap::new(),
530 };
531
532 BaseKernel::new(
533 "rmsprop_optimizer",
534 cuda_source,
535 cuda_source,
536 "",
537 "",
538 cuda_source,
539 metadata,
540 )
541}
542
543#[allow(dead_code)]
545fn create_adagrad_optimizer_kernel() -> BaseKernel {
546 let cuda_source = include_str!("adagrad_optimizer.cu");
547 let metadata = KernelMetadata {
548 workgroup_size: [256, 1, 1],
549 local_memory_usage: 0,
550 supports_tensor_cores: false,
551 operationtype: OperationType::ComputeIntensive,
552 backend_metadata: HashMap::new(),
553 };
554
555 BaseKernel::new(
556 "adagrad_optimizer",
557 cuda_source,
558 cuda_source,
559 "",
560 "",
561 cuda_source,
562 metadata,
563 )
564}
565
566#[allow(dead_code)]
568fn create_lamb_optimizer_kernel() -> BaseKernel {
569 let cuda_source = include_str!("lamb_optimizer.cu");
570 let metadata = KernelMetadata {
571 workgroup_size: [256, 1, 1],
572 local_memory_usage: 0,
573 supports_tensor_cores: false,
574 operationtype: OperationType::ComputeIntensive,
575 backend_metadata: HashMap::new(),
576 };
577
578 BaseKernel::new(
579 "lamb_optimizer",
580 cuda_source,
581 cuda_source,
582 "",
583 "",
584 cuda_source,
585 metadata,
586 )
587}
588
589#[allow(dead_code)]
591fn create_memcpy_kernel() -> BaseKernel {
592 let cuda_source = include_str!("memcpy.cu");
593 let metadata = KernelMetadata {
594 workgroup_size: [256, 1, 1],
595 local_memory_usage: 0,
596 supports_tensor_cores: false,
597 operationtype: OperationType::MemoryIntensive,
598 backend_metadata: HashMap::new(),
599 };
600
601 BaseKernel::new(
602 "memcpy",
603 cuda_source,
604 cuda_source,
605 "",
606 "",
607 cuda_source,
608 metadata,
609 )
610}
611
612#[allow(dead_code)]
614fn create_fill_kernel() -> BaseKernel {
615 let cuda_source = include_str!("fill.cu");
616 let metadata = KernelMetadata {
617 workgroup_size: [256, 1, 1],
618 local_memory_usage: 0,
619 supports_tensor_cores: false,
620 operationtype: OperationType::MemoryIntensive,
621 backend_metadata: HashMap::new(),
622 };
623
624 BaseKernel::new(
625 "fill",
626 cuda_source,
627 cuda_source,
628 "",
629 "",
630 cuda_source,
631 metadata,
632 )
633}
634
635#[allow(dead_code)]
637fn create_reduce_sum_kernel() -> BaseKernel {
638 let cuda_source = include_str!("reduce_sum.cu");
639 let metadata = KernelMetadata {
640 workgroup_size: [256, 1, 1],
641 local_memory_usage: 1024, supports_tensor_cores: false,
643 operationtype: OperationType::ComputeIntensive,
644 backend_metadata: HashMap::new(),
645 };
646
647 BaseKernel::new(
648 "reduce_sum",
649 cuda_source,
650 cuda_source,
651 "",
652 "",
653 cuda_source,
654 metadata,
655 )
656}
657
658#[allow(dead_code)]
660fn create_reduce_max_kernel() -> BaseKernel {
661 let cuda_source = include_str!("reduce_max.cu");
662 let metadata = KernelMetadata {
663 workgroup_size: [256, 1, 1],
664 local_memory_usage: 1024, supports_tensor_cores: false,
666 operationtype: OperationType::ComputeIntensive,
667 backend_metadata: HashMap::new(),
668 };
669
670 BaseKernel::new(
671 "reduce_max",
672 cuda_source,
673 cuda_source,
674 "",
675 "",
676 cuda_source,
677 metadata,
678 )
679}