1use crate::Device;
4use torsh_core::dtype::DType;
5
6#[cfg(not(feature = "std"))]
7use alloc::{boxed::Box, string::String, vec::Vec};
8
9#[derive(Debug)]
11pub struct Kernel {
12 pub id: usize,
14
15 pub device: Device,
17
18 pub name: String,
20
21 pub descriptor: KernelDescriptor,
23
24 pub handle: KernelHandle,
26
27 pub metadata: KernelMetadata,
29}
30
31impl Kernel {
32 pub fn new(
34 id: usize,
35 device: Device,
36 name: String,
37 descriptor: KernelDescriptor,
38 handle: KernelHandle,
39 metadata: KernelMetadata,
40 ) -> Self {
41 Self {
42 id,
43 device,
44 name,
45 descriptor,
46 handle,
47 metadata,
48 }
49 }
50
51 pub fn id(&self) -> usize {
53 self.id
54 }
55
56 pub fn name(&self) -> &str {
58 &self.name
59 }
60
61 pub fn device(&self) -> &Device {
63 &self.device
64 }
65
66 pub fn metadata(&self) -> &KernelMetadata {
68 &self.metadata
69 }
70
71 pub fn handle(&self) -> &KernelHandle {
73 &self.handle
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct KernelDescriptor {
80 pub name: String,
82
83 pub source: KernelSource,
85
86 pub compile_options: Vec<String>,
88
89 pub parameters: Vec<KernelParameter>,
91
92 pub workgroup_size_hint: Option<(u32, u32, u32)>,
94
95 pub cache: bool,
97}
98
99impl KernelDescriptor {
100 pub fn new(name: String, source: KernelSource) -> Self {
102 Self {
103 name,
104 source,
105 compile_options: Vec::new(),
106 parameters: Vec::new(),
107 workgroup_size_hint: None,
108 cache: true,
109 }
110 }
111
112 pub fn with_compile_option(mut self, option: String) -> Self {
114 self.compile_options.push(option);
115 self
116 }
117
118 pub fn with_parameter(mut self, param: KernelParameter) -> Self {
120 self.parameters.push(param);
121 self
122 }
123
124 pub fn with_workgroup_size_hint(mut self, size: (u32, u32, u32)) -> Self {
126 self.workgroup_size_hint = Some(size);
127 self
128 }
129
130 pub fn without_cache(mut self) -> Self {
132 self.cache = false;
133 self
134 }
135}
136
/// Where a kernel's code comes from.
///
/// Derives `PartialEq`/`Eq`/`Hash` (alongside `Debug`/`Clone`) so sources can
/// be compared and used as keys, e.g. when looking up a cached compilation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelSource {
    /// Human-readable source code written in `language`.
    Source {
        code: String,
        language: KernelLanguage,
    },

    /// Precompiled bytecode in the given `format`.
    Bytecode {
        data: Vec<u8>,
        format: BytecodeFormat,
    },

    /// SPIR-V module stored as 32-bit words.
    SpirV { data: Vec<u32> },

    /// Platform-specific binary; `platform` names the target (e.g. `"cuda"`).
    Binary { data: Vec<u8>, platform: String },
}

/// Language a [`KernelSource::Source`] is written in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelLanguage {
    /// WGSL (WebGPU Shading Language).
    Wgsl,

    /// HLSL.
    Hlsl,

    /// GLSL.
    Glsl,

    /// Metal Shading Language.
    Metal,

    /// CUDA C/C++.
    Cuda,

    /// OpenCL C.
    OpenCl,

    /// Any other language, identified by name.
    Custom(String),
}

/// Encoding of a [`KernelSource::Bytecode`] payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BytecodeFormat {
    /// SPIR-V.
    SpirV,

    /// DXIL.
    Dxil,

    /// Metal AIR.
    MetalAir,

    /// PTX.
    Ptx,

    /// Any other format, identified by name.
    Custom(String),
}
202
203#[derive(Debug, Clone)]
205pub struct KernelParameter {
206 pub name: String,
208
209 pub param_type: KernelParameterType,
211
212 pub binding: Option<u32>,
214
215 pub readonly: bool,
217}
218
219impl KernelParameter {
220 pub fn buffer(name: String, dtype: DType, readonly: bool) -> Self {
222 Self {
223 name,
224 param_type: KernelParameterType::Buffer { dtype },
225 binding: None,
226 readonly,
227 }
228 }
229
230 pub fn uniform(name: String, dtype: DType) -> Self {
232 Self {
233 name,
234 param_type: KernelParameterType::Uniform { dtype },
235 binding: None,
236 readonly: true,
237 }
238 }
239
240 pub fn with_binding(mut self, binding: u32) -> Self {
242 self.binding = Some(binding);
243 self
244 }
245}
246
247#[derive(Debug, Clone)]
249pub enum KernelParameterType {
250 Buffer { dtype: DType },
252
253 Uniform { dtype: DType },
255
256 Texture { dimensions: u32, dtype: DType },
258
259 Sampler,
261
262 Scalar { dtype: DType },
264}
265
/// Backend-specific reference to a compiled kernel.
///
/// `Cpu` and `Generic` are always available; `Cuda`, `Metal` and `WebGpu`
/// exist only when their `cfg` conditions below are met.
#[derive(Debug)]
pub enum KernelHandle {
    /// CPU kernel stored as a type-erased raw pointer (presumably the entry
    /// function's address — confirm with the code that constructs it).
    Cpu { function: *const () },

    /// CUDA module and function, stored as opaque 64-bit values.
    #[cfg(feature = "cuda")]
    Cuda { module: u64, function: u64 },

    /// Metal library and function identifiers.
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { library_id: u64, function_id: u64 },

    /// WebGPU shader-module identifier and entry-point name.
    #[cfg(feature = "webgpu")]
    WebGpu {
        shader_module_id: String,
        entry_point: String,
    },

    /// Arbitrary type-erased backend handle.
    ///
    /// Spelled `core::any::Any` (the same trait as `std::any::Any`) so the
    /// enum also compiles when the `std` feature is disabled.
    Generic {
        handle: Box<dyn core::any::Any + Send + Sync>,
    },
}

impl Clone for KernelHandle {
    /// Copies the handle's identifiers/pointer; no backend object is duplicated.
    ///
    /// # Panics
    ///
    /// Panics for [`KernelHandle::Generic`], because a boxed `dyn Any` cannot
    /// be cloned generically.
    fn clone(&self) -> Self {
        match self {
            KernelHandle::Cpu { function } => KernelHandle::Cpu {
                function: *function,
            },
            #[cfg(feature = "cuda")]
            KernelHandle::Cuda { module, function } => KernelHandle::Cuda {
                module: *module,
                function: *function,
            },
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            KernelHandle::Metal {
                library_id,
                function_id,
            } => KernelHandle::Metal {
                library_id: *library_id,
                function_id: *function_id,
            },
            #[cfg(feature = "webgpu")]
            KernelHandle::WebGpu {
                shader_module_id,
                entry_point,
            } => KernelHandle::WebGpu {
                shader_module_id: shader_module_id.clone(),
                entry_point: entry_point.clone(),
            },
            KernelHandle::Generic { .. } => {
                panic!("Cannot clone Generic kernel handles")
            }
        }
    }
}

// SAFETY: every variant except `Cpu` holds only `Send + Sync` data (integers,
// `String`s, or a `Box<dyn Any + Send + Sync>`). `Cpu` stores a raw
// `*const ()`, which is what removes the auto traits. The handle never
// dereferences that pointer itself; these impls rely on whoever constructs a
// `Cpu` handle storing only a pointer that is safe to share and use across
// threads (e.g. an immutable function entry point).
// NOTE(review): nothing in this file enforces that contract.
unsafe impl Send for KernelHandle {}
// SAFETY: see the `Send` impl above; `&KernelHandle` exposes the pointer only
// by copy, never by mutation.
unsafe impl Sync for KernelHandle {}
332
/// Information recorded about a kernel's compilation.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Compilation time in milliseconds.
    pub compile_time_ms: f64,

    /// Size of the compiled binary (unit not fixed here; presumably bytes).
    pub binary_size: usize,

    /// Registers used per thread; `None` when not available.
    pub registers_per_thread: Option<u32>,

    /// Shared memory used by the kernel; `None` when not available
    /// (unit not fixed here; presumably bytes).
    pub shared_memory_usage: Option<usize>,

    /// Largest workgroup size `(x, y, z)` for the kernel; `None` when not available.
    pub max_workgroup_size: Option<(u32, u32, u32)>,

    /// Compiler version string; `"Unknown"` by default.
    pub compiler_version: String,

    /// Compiler warnings.
    pub warnings: Vec<String>,

    /// Performance hints.
    pub performance_hints: Vec<String>,
}

impl Default for KernelMetadata {
    /// Zero/empty metadata with `compiler_version` set to `"Unknown"`.
    fn default() -> Self {
        Self {
            compile_time_ms: 0.0,
            binary_size: 0,
            registers_per_thread: None,
            shared_memory_usage: None,
            max_workgroup_size: None,
            // `String::from` rather than `.to_string()`: under `no_std` this
            // file imports only `String`/`Vec`/`Box` from `alloc`, so the
            // `ToString` trait would not be in scope.
            compiler_version: String::from("Unknown"),
            warnings: Vec::new(),
            performance_hints: Vec::new(),
        }
    }
}
375
/// Dispatch geometry for a kernel launch: threads per workgroup, number of
/// workgroups per axis, and optional shared-memory and stream settings.
#[derive(Debug, Clone)]
pub struct KernelLaunchConfig {
    /// Threads per workgroup along `(x, y, z)`.
    pub workgroup_size: (u32, u32, u32),

    /// Number of workgroups along `(x, y, z)`.
    pub workgroup_count: (u32, u32, u32),

    /// Optional shared-memory request (unit not fixed here; presumably bytes).
    pub shared_memory_size: Option<usize>,

    /// Optional stream/queue identifier to launch on.
    pub stream_id: Option<usize>,
}

impl KernelLaunchConfig {
    /// Number of workgroups of `local` threads needed to cover `global`
    /// items, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `local` is zero. A zero-sized workgroup cannot cover any work,
    /// and without this check the caller would only see an opaque
    /// divide-by-zero panic.
    fn workgroups_needed(global: u32, local: u32) -> u32 {
        assert!(
            local > 0,
            "workgroup size must be non-zero in every dimension"
        );
        global.div_ceil(local)
    }

    /// 1-D launch covering `global_size` items.
    ///
    /// `workgroup_size` defaults to 256 threads.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is `Some(0)`.
    pub fn linear(global_size: u32, workgroup_size: Option<u32>) -> Self {
        let wg_size = workgroup_size.unwrap_or(256);
        let wg_count = Self::workgroups_needed(global_size, wg_size);

        Self {
            workgroup_size: (wg_size, 1, 1),
            workgroup_count: (wg_count, 1, 1),
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// 2-D launch covering a `global_size.0 x global_size.1` grid.
    ///
    /// `workgroup_size` defaults to `(16, 16)`.
    ///
    /// # Panics
    ///
    /// Panics if either workgroup dimension is zero.
    pub fn grid_2d(global_size: (u32, u32), workgroup_size: Option<(u32, u32)>) -> Self {
        let wg_size = workgroup_size.unwrap_or((16, 16));
        let wg_count = (
            Self::workgroups_needed(global_size.0, wg_size.0),
            Self::workgroups_needed(global_size.1, wg_size.1),
        );

        Self {
            workgroup_size: (wg_size.0, wg_size.1, 1),
            workgroup_count: (wg_count.0, wg_count.1, 1),
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// 3-D launch covering a `global_size.0 x global_size.1 x global_size.2` grid.
    ///
    /// `workgroup_size` defaults to `(8, 8, 8)`.
    ///
    /// # Panics
    ///
    /// Panics if any workgroup dimension is zero.
    pub fn grid_3d(global_size: (u32, u32, u32), workgroup_size: Option<(u32, u32, u32)>) -> Self {
        let wg_size = workgroup_size.unwrap_or((8, 8, 8));
        let wg_count = (
            Self::workgroups_needed(global_size.0, wg_size.0),
            Self::workgroups_needed(global_size.1, wg_size.1),
            Self::workgroups_needed(global_size.2, wg_size.2),
        );

        Self {
            workgroup_size: wg_size,
            workgroup_count: wg_count,
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// Requests `size` units of shared memory for the launch.
    pub fn with_shared_memory(mut self, size: usize) -> Self {
        self.shared_memory_size = Some(size);
        self
    }

    /// Targets the stream/queue identified by `stream_id`.
    pub fn with_stream(mut self, stream_id: usize) -> Self {
        self.stream_id = Some(stream_id);
        self
    }

    /// Total threads launched: the product of all workgroup-size and
    /// workgroup-count dimensions.
    ///
    /// The six-way product of `u32` values can exceed `u64::MAX`. Instead of
    /// panicking in debug builds or silently wrapping in release builds, the
    /// result saturates at `u64::MAX`.
    pub fn total_threads(&self) -> u64 {
        let (sx, sy, sz) = self.workgroup_size;
        let (cx, cy, cz) = self.workgroup_count;
        [sx, sy, sz, cx, cy, cz]
            .iter()
            .fold(1u64, |acc, &dim| acc.saturating_mul(u64::from(dim)))
    }
}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType};

    /// Builds a CPU device with default info for tests that need a `Device`.
    fn create_test_device() -> Device {
        let info = DeviceInfo::default();
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), info)
    }

    #[test]
    fn test_kernel_descriptor_creation() {
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };

        let desc = KernelDescriptor::new("test_kernel".to_string(), source);

        assert_eq!(desc.name, "test_kernel");
        assert!(desc.compile_options.is_empty());
        assert!(desc.parameters.is_empty());
        assert_eq!(desc.workgroup_size_hint, None);
        assert!(desc.cache);
    }

    #[test]
    fn test_kernel_descriptor_builder() {
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Cuda,
        };

        let param = KernelParameter::buffer("input".to_string(), DType::F32, true);

        let desc = KernelDescriptor::new("complex_kernel".to_string(), source)
            .with_compile_option("-O3".to_string())
            .with_compile_option("--fast-math".to_string())
            .with_parameter(param)
            .with_workgroup_size_hint((256, 1, 1))
            .without_cache();

        assert_eq!(desc.name, "complex_kernel");
        assert_eq!(desc.compile_options.len(), 2);
        assert!(desc.compile_options.contains(&"-O3".to_string()));
        assert!(desc.compile_options.contains(&"--fast-math".to_string()));
        assert_eq!(desc.parameters.len(), 1);
        assert_eq!(desc.workgroup_size_hint, Some((256, 1, 1)));
        assert!(!desc.cache);
    }

    #[test]
    fn test_kernel_source_variants() {
        let source1 = KernelSource::Source {
            code: "vertex main() {}".to_string(),
            language: KernelLanguage::Metal,
        };

        let source2 = KernelSource::Bytecode {
            data: vec![0x12, 0x34, 0x56, 0x78],
            format: BytecodeFormat::SpirV,
        };

        let source3 = KernelSource::SpirV {
            data: vec![0x07230203, 0x00010000],
        };

        let source4 = KernelSource::Binary {
            data: vec![0xCA, 0xFE, 0xBA, 0xBE],
            platform: "cuda".to_string(),
        };

        match source1 {
            KernelSource::Source { language, .. } => assert_eq!(language, KernelLanguage::Metal),
            _ => panic!("Wrong variant"),
        }

        match source2 {
            KernelSource::Bytecode { format, .. } => assert_eq!(format, BytecodeFormat::SpirV),
            _ => panic!("Wrong variant"),
        }

        match source3 {
            KernelSource::SpirV { .. } => {}
            _ => panic!("Wrong variant"),
        }

        match source4 {
            KernelSource::Binary { platform, .. } => assert_eq!(platform, "cuda"),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_kernel_language_variants() {
        let languages = [
            KernelLanguage::Wgsl,
            KernelLanguage::Hlsl,
            KernelLanguage::Glsl,
            KernelLanguage::Metal,
            KernelLanguage::Cuda,
            KernelLanguage::OpenCl,
            KernelLanguage::Custom("MyLang".to_string()),
        ];

        // Every pair of distinct variants must compare unequal.
        for (i, lang1) in languages.iter().enumerate() {
            for (j, lang2) in languages.iter().enumerate() {
                if i != j {
                    assert_ne!(lang1, lang2);
                }
            }
        }
    }

    #[test]
    fn test_bytecode_format_variants() {
        let formats = [
            BytecodeFormat::SpirV,
            BytecodeFormat::Dxil,
            BytecodeFormat::MetalAir,
            BytecodeFormat::Ptx,
            BytecodeFormat::Custom("MyFormat".to_string()),
        ];

        // Every pair of distinct variants must compare unequal.
        for (i, format1) in formats.iter().enumerate() {
            for (j, format2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(format1, format2);
                }
            }
        }
    }

    #[test]
    fn test_kernel_parameter_creation() {
        let buffer_param = KernelParameter::buffer("data".to_string(), DType::F32, false);
        assert_eq!(buffer_param.name, "data");
        assert!(!buffer_param.readonly);
        assert_eq!(buffer_param.binding, None);
        match buffer_param.param_type {
            KernelParameterType::Buffer { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let uniform_param = KernelParameter::uniform("scale".to_string(), DType::F32);
        assert_eq!(uniform_param.name, "scale");
        assert!(uniform_param.readonly);
        match uniform_param.param_type {
            KernelParameterType::Uniform { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let bound_param = buffer_param.with_binding(0);
        assert_eq!(bound_param.binding, Some(0));
    }

    #[test]
    fn test_kernel_parameter_types() {
        let buffer_type = KernelParameterType::Buffer { dtype: DType::I32 };
        let uniform_type = KernelParameterType::Uniform { dtype: DType::F64 };
        let texture_type = KernelParameterType::Texture {
            dimensions: 2,
            dtype: DType::F32,
        };
        let sampler_type = KernelParameterType::Sampler;
        let scalar_type = KernelParameterType::Scalar { dtype: DType::U8 };

        // Neighbouring variants must have distinct discriminants.
        assert_ne!(
            std::mem::discriminant(&buffer_type),
            std::mem::discriminant(&uniform_type)
        );
        assert_ne!(
            std::mem::discriminant(&uniform_type),
            std::mem::discriminant(&texture_type)
        );
        assert_ne!(
            std::mem::discriminant(&texture_type),
            std::mem::discriminant(&sampler_type)
        );
        assert_ne!(
            std::mem::discriminant(&sampler_type),
            std::mem::discriminant(&scalar_type)
        );
    }

    #[test]
    fn test_kernel_handle_cpu() {
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };

        match handle {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    fn test_kernel_handle_clone_cpu() {
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };

        match handle.clone() {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    #[should_panic(expected = "Cannot clone Generic kernel handles")]
    fn test_kernel_handle_clone_generic_panics() {
        let handle = KernelHandle::Generic {
            handle: Box::new(42u32),
        };
        let _ = handle.clone();
    }

    #[test]
    fn test_kernel_metadata_default() {
        let metadata = KernelMetadata::default();

        assert_eq!(metadata.compile_time_ms, 0.0);
        assert_eq!(metadata.binary_size, 0);
        assert_eq!(metadata.registers_per_thread, None);
        assert_eq!(metadata.shared_memory_usage, None);
        assert_eq!(metadata.max_workgroup_size, None);
        assert_eq!(metadata.compiler_version, "Unknown");
        assert!(metadata.warnings.is_empty());
        assert!(metadata.performance_hints.is_empty());
    }

    #[test]
    fn test_kernel_creation() {
        let device = create_test_device();
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };
        let desc = KernelDescriptor::new("test".to_string(), source);
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };
        let metadata = KernelMetadata::default();

        let kernel = Kernel::new(
            1,
            device.clone(),
            "test_kernel".to_string(),
            desc,
            handle,
            metadata,
        );

        assert_eq!(kernel.id(), 1);
        assert_eq!(kernel.name(), "test_kernel");
        assert_eq!(kernel.device().id(), device.id());

        // The kernel name and the descriptor name are stored independently.
        assert_eq!(kernel.descriptor.name, "test");
        assert_eq!(kernel.metadata().compiler_version, "Unknown");
        match kernel.handle() {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    fn test_kernel_launch_config_linear() {
        let config = KernelLaunchConfig::linear(1000, Some(64));

        assert_eq!(config.workgroup_size, (64, 1, 1));
        // ceil(1000 / 64) = 16
        assert_eq!(config.workgroup_count, (16, 1, 1));
        assert_eq!(config.shared_memory_size, None);
        assert_eq!(config.stream_id, None);
        assert_eq!(config.total_threads(), 64 * 16);

        // Default workgroup size is 256: ceil(1000 / 256) = 4
        let config_default = KernelLaunchConfig::linear(1000, None);
        assert_eq!(config_default.workgroup_size, (256, 1, 1));
        assert_eq!(config_default.workgroup_count, (4, 1, 1));

        // Exact division must not add an extra workgroup.
        let config_exact = KernelLaunchConfig::linear(1024, Some(256));
        assert_eq!(config_exact.workgroup_count, (4, 1, 1));
    }

    #[test]
    fn test_kernel_launch_config_2d() {
        let config = KernelLaunchConfig::grid_2d((100, 50), Some((10, 5)));

        assert_eq!(config.workgroup_size, (10, 5, 1));
        // (100 / 10, 50 / 5) = (10, 10)
        assert_eq!(config.workgroup_count, (10, 10, 1));
        assert_eq!(config.total_threads(), 10 * 5 * 10 * 10);

        // Default workgroup size is (16, 16): (ceil(100 / 16), ceil(50 / 16)) = (7, 4)
        let config_default = KernelLaunchConfig::grid_2d((100, 50), None);
        assert_eq!(config_default.workgroup_size, (16, 16, 1));
        assert_eq!(config_default.workgroup_count, (7, 4, 1));
    }

    #[test]
    fn test_kernel_launch_config_3d() {
        let config = KernelLaunchConfig::grid_3d((64, 32, 16), Some((8, 4, 2)));

        assert_eq!(config.workgroup_size, (8, 4, 2));
        // (64 / 8, 32 / 4, 16 / 2) = (8, 8, 8)
        assert_eq!(config.workgroup_count, (8, 8, 8));
        assert_eq!(config.total_threads(), 8 * 4 * 2 * 8 * 8 * 8);

        // Default workgroup size is (8, 8, 8): (64 / 8, 32 / 8, 16 / 8) = (8, 4, 2)
        let config_default = KernelLaunchConfig::grid_3d((64, 32, 16), None);
        assert_eq!(config_default.workgroup_size, (8, 8, 8));
        assert_eq!(config_default.workgroup_count, (8, 4, 2));
    }

    #[test]
    fn test_kernel_launch_config_builder() {
        let config = KernelLaunchConfig::linear(1000, Some(128))
            .with_shared_memory(4096)
            .with_stream(1);

        assert_eq!(config.shared_memory_size, Some(4096));
        assert_eq!(config.stream_id, Some(1));
    }
}