Skip to main content

torsh_backend/
kernel.rs

1//! Compute kernel abstraction and management
2
3use crate::Device;
4use torsh_core::dtype::DType;
5
6#[cfg(not(feature = "std"))]
7use alloc::{boxed::Box, string::String, vec::Vec};
8
9/// Compute kernel handle
10#[derive(Debug)]
11pub struct Kernel {
12    /// Unique kernel ID
13    pub id: usize,
14
15    /// Device this kernel is compiled for
16    pub device: Device,
17
18    /// Kernel name
19    pub name: String,
20
21    /// Kernel descriptor used for creation
22    pub descriptor: KernelDescriptor,
23
24    /// Backend-specific handle
25    pub handle: KernelHandle,
26
27    /// Kernel metadata
28    pub metadata: KernelMetadata,
29}
30
31impl Kernel {
32    /// Create a new kernel
33    pub fn new(
34        id: usize,
35        device: Device,
36        name: String,
37        descriptor: KernelDescriptor,
38        handle: KernelHandle,
39        metadata: KernelMetadata,
40    ) -> Self {
41        Self {
42            id,
43            device,
44            name,
45            descriptor,
46            handle,
47            metadata,
48        }
49    }
50
51    /// Get kernel ID
52    pub fn id(&self) -> usize {
53        self.id
54    }
55
56    /// Get kernel name
57    pub fn name(&self) -> &str {
58        &self.name
59    }
60
61    /// Get the device this kernel is compiled for
62    pub fn device(&self) -> &Device {
63        &self.device
64    }
65
66    /// Get kernel metadata
67    pub fn metadata(&self) -> &KernelMetadata {
68        &self.metadata
69    }
70
71    /// Get backend-specific handle
72    pub fn handle(&self) -> &KernelHandle {
73        &self.handle
74    }
75}
76
77/// Kernel descriptor for creation
78#[derive(Debug, Clone)]
79pub struct KernelDescriptor {
80    /// Kernel name/entry point
81    pub name: String,
82
83    /// Kernel source code or bytecode
84    pub source: KernelSource,
85
86    /// Compilation options
87    pub compile_options: Vec<String>,
88
89    /// Kernel parameters description
90    pub parameters: Vec<KernelParameter>,
91
92    /// Workgroup size hint
93    pub workgroup_size_hint: Option<(u32, u32, u32)>,
94
95    /// Whether to cache the compiled kernel
96    pub cache: bool,
97}
98
99impl KernelDescriptor {
100    /// Create a new kernel descriptor
101    pub fn new(name: String, source: KernelSource) -> Self {
102        Self {
103            name,
104            source,
105            compile_options: Vec::new(),
106            parameters: Vec::new(),
107            workgroup_size_hint: None,
108            cache: true,
109        }
110    }
111
112    /// Add a compilation option
113    pub fn with_compile_option(mut self, option: String) -> Self {
114        self.compile_options.push(option);
115        self
116    }
117
118    /// Add a kernel parameter
119    pub fn with_parameter(mut self, param: KernelParameter) -> Self {
120        self.parameters.push(param);
121        self
122    }
123
124    /// Set workgroup size hint
125    pub fn with_workgroup_size_hint(mut self, size: (u32, u32, u32)) -> Self {
126        self.workgroup_size_hint = Some(size);
127        self
128    }
129
130    /// Disable kernel caching
131    pub fn without_cache(mut self) -> Self {
132        self.cache = false;
133        self
134    }
135}
136
137/// Kernel source code or bytecode
138#[derive(Debug, Clone)]
139pub enum KernelSource {
140    /// High-level source code (HLSL, GLSL, etc.)
141    Source {
142        code: String,
143        language: KernelLanguage,
144    },
145
146    /// Pre-compiled bytecode
147    Bytecode {
148        data: Vec<u8>,
149        format: BytecodeFormat,
150    },
151
152    /// SPIR-V bytecode
153    SpirV { data: Vec<u32> },
154
155    /// Platform-specific binary
156    Binary { data: Vec<u8>, platform: String },
157}
158
/// Language a kernel's source code is written in.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum KernelLanguage {
    /// WGSL, the WebGPU Shading Language
    Wgsl,

    /// HLSL, the High Level Shading Language
    Hlsl,

    /// GLSL, the OpenGL Shading Language
    Glsl,

    /// The Metal Shading Language
    Metal,

    /// CUDA C++
    Cuda,

    /// OpenCL C
    OpenCl,

    /// Any other language, identified by name
    Custom(String),
}
183
/// Format of pre-compiled kernel bytecode.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BytecodeFormat {
    /// SPIR-V
    SpirV,

    /// DXIL, the DirectX Intermediate Language
    Dxil,

    /// Metal AIR, Apple's intermediate representation
    MetalAir,

    /// CUDA PTX
    Ptx,

    /// Any other format, identified by name
    Custom(String),
}
202
203/// Kernel parameter description
204#[derive(Debug, Clone)]
205pub struct KernelParameter {
206    /// Parameter name
207    pub name: String,
208
209    /// Parameter type
210    pub param_type: KernelParameterType,
211
212    /// Parameter binding location
213    pub binding: Option<u32>,
214
215    /// Whether parameter is read-only
216    pub readonly: bool,
217}
218
219impl KernelParameter {
220    /// Create a buffer parameter
221    pub fn buffer(name: String, dtype: DType, readonly: bool) -> Self {
222        Self {
223            name,
224            param_type: KernelParameterType::Buffer { dtype },
225            binding: None,
226            readonly,
227        }
228    }
229
230    /// Create a uniform parameter
231    pub fn uniform(name: String, dtype: DType) -> Self {
232        Self {
233            name,
234            param_type: KernelParameterType::Uniform { dtype },
235            binding: None,
236            readonly: true,
237        }
238    }
239
240    /// Set binding location
241    pub fn with_binding(mut self, binding: u32) -> Self {
242        self.binding = Some(binding);
243        self
244    }
245}
246
247/// Kernel parameter type
248#[derive(Debug, Clone)]
249pub enum KernelParameterType {
250    /// Buffer parameter
251    Buffer { dtype: DType },
252
253    /// Uniform data parameter
254    Uniform { dtype: DType },
255
256    /// Texture/image parameter
257    Texture { dimensions: u32, dtype: DType },
258
259    /// Sampler parameter
260    Sampler,
261
262    /// Scalar parameter
263    Scalar { dtype: DType },
264}
265
/// Backend-specific kernel handle.
///
/// Which variants exist depends on the enabled cargo features (and, for
/// Metal, the target platform).
#[derive(Debug)]
pub enum KernelHandle {
    /// CPU kernel (function pointer)
    Cpu { function: *const () },

    /// CUDA kernel
    #[cfg(feature = "cuda")]
    Cuda { module: u64, function: u64 },

    /// Metal kernel
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { library_id: u64, function_id: u64 },

    /// WebGPU kernel
    #[cfg(feature = "webgpu")]
    WebGpu {
        shader_module_id: String,
        entry_point: String,
    },

    /// Generic handle for custom backends.
    ///
    /// The payload is named through `core::any::Any` (the same trait that
    /// `std::any::Any` re-exports) so this type also compiles when the `std`
    /// feature is disabled.
    Generic {
        handle: Box<dyn core::any::Any + Send + Sync>,
    },
}

impl Clone for KernelHandle {
    /// Duplicate the handle by copying its plain-data fields.
    ///
    /// # Panics
    ///
    /// Panics for [`KernelHandle::Generic`]: its boxed `dyn Any` payload has no
    /// way to be cloned, so backends should avoid `Generic` handles for
    /// kernels that need to be cloned.
    fn clone(&self) -> Self {
        match self {
            KernelHandle::Cpu { function } => KernelHandle::Cpu {
                function: *function,
            },
            #[cfg(feature = "cuda")]
            KernelHandle::Cuda { module, function } => KernelHandle::Cuda {
                module: *module,
                function: *function,
            },
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            KernelHandle::Metal {
                library_id,
                function_id,
            } => KernelHandle::Metal {
                library_id: *library_id,
                function_id: *function_id,
            },
            #[cfg(feature = "webgpu")]
            KernelHandle::WebGpu {
                shader_module_id,
                entry_point,
            } => KernelHandle::WebGpu {
                shader_module_id: shader_module_id.clone(),
                entry_point: entry_point.clone(),
            },
            KernelHandle::Generic { .. } => {
                panic!("Cannot clone Generic kernel handles")
            }
        }
    }
}

// SAFETY: every field of every variant except `Cpu::function` (`u64`, `String`,
// `Box<dyn Any + Send + Sync>`) is already `Send + Sync`. These impls are needed
// only because of that raw `*const ()`, which this module copies but never
// dereferences.
// NOTE(review): soundness depends on backends storing a pointer to code or
// immutable data that may be used from any thread — confirm with the CPU backend.
unsafe impl Send for KernelHandle {}
unsafe impl Sync for KernelHandle {}
332
/// Kernel metadata and compilation information
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Compilation time in milliseconds
    pub compile_time_ms: f64,

    /// Compiled binary size in bytes
    pub binary_size: usize,

    /// Number of registers used per thread, if reported
    pub registers_per_thread: Option<u32>,

    /// Shared memory usage in bytes, if reported
    pub shared_memory_usage: Option<usize>,

    /// Maximum workgroup size, if reported
    pub max_workgroup_size: Option<(u32, u32, u32)>,

    /// Compiler version
    pub compiler_version: String,

    /// Compilation warnings
    pub warnings: Vec<String>,

    /// Performance hints from compiler
    pub performance_hints: Vec<String>,
}

impl Default for KernelMetadata {
    /// Metadata for a kernel with no recorded compilation information.
    ///
    /// Time and size are zero, the optional resource figures are `None`, the
    /// compiler version is `"Unknown"`, and there are no warnings or hints.
    fn default() -> Self {
        Self {
            compile_time_ms: 0.0,
            binary_size: 0,
            registers_per_thread: None,
            shared_memory_usage: None,
            max_workgroup_size: None,
            // `String::from` instead of `.to_string()`: `ToString` is not in the
            // `core` prelude and this file imports only `Box`, `String` and `Vec`
            // from `alloc` for `no_std` builds.
            compiler_version: String::from("Unknown"),
            warnings: Vec::new(),
            performance_hints: Vec::new(),
        }
    }
}
375
/// Kernel launch configuration
#[derive(Debug, Clone)]
pub struct KernelLaunchConfig {
    /// Workgroup size (local work size) along x, y, z
    pub workgroup_size: (u32, u32, u32),

    /// Number of workgroups along x, y, z (global work size divided by the
    /// workgroup size, rounded up)
    pub workgroup_count: (u32, u32, u32),

    /// Shared memory size in bytes
    pub shared_memory_size: Option<usize>,

    /// Stream/queue for asynchronous execution
    pub stream_id: Option<usize>,
}

impl KernelLaunchConfig {
    /// Number of workgroups of `workgroup_size` needed to cover `global_size`
    /// items along one axis, rounded up so every item is covered.
    ///
    /// # Panics
    ///
    /// Panics with a message naming `axis` if `workgroup_size` is zero.
    /// Previously this surfaced as an opaque divide-by-zero panic.
    fn workgroups_needed(global_size: u32, workgroup_size: u32, axis: &str) -> u32 {
        assert!(
            workgroup_size != 0,
            "workgroup size along the {} axis must be non-zero",
            axis
        );
        global_size.div_ceil(workgroup_size)
    }

    /// Build a configuration with the given geometry and no shared memory or
    /// stream assigned.
    fn from_geometry(workgroup_size: (u32, u32, u32), workgroup_count: (u32, u32, u32)) -> Self {
        Self {
            workgroup_size,
            workgroup_count,
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// Create a 1D launch configuration covering `global_size` items.
    ///
    /// `workgroup_size` defaults to 256.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is `Some(0)`.
    pub fn linear(global_size: u32, workgroup_size: Option<u32>) -> Self {
        let wg = workgroup_size.unwrap_or(256);
        let count = Self::workgroups_needed(global_size, wg, "x");
        Self::from_geometry((wg, 1, 1), (count, 1, 1))
    }

    /// Create a 2D launch configuration covering a `global_size` grid.
    ///
    /// `workgroup_size` defaults to `(16, 16)`.
    ///
    /// # Panics
    ///
    /// Panics if either workgroup dimension is zero.
    pub fn grid_2d(global_size: (u32, u32), workgroup_size: Option<(u32, u32)>) -> Self {
        let (wx, wy) = workgroup_size.unwrap_or((16, 16));
        let count = (
            Self::workgroups_needed(global_size.0, wx, "x"),
            Self::workgroups_needed(global_size.1, wy, "y"),
            1,
        );
        Self::from_geometry((wx, wy, 1), count)
    }

    /// Create a 3D launch configuration covering a `global_size` volume.
    ///
    /// `workgroup_size` defaults to `(8, 8, 8)`.
    ///
    /// # Panics
    ///
    /// Panics if any workgroup dimension is zero.
    pub fn grid_3d(global_size: (u32, u32, u32), workgroup_size: Option<(u32, u32, u32)>) -> Self {
        let wg = workgroup_size.unwrap_or((8, 8, 8));
        let count = (
            Self::workgroups_needed(global_size.0, wg.0, "x"),
            Self::workgroups_needed(global_size.1, wg.1, "y"),
            Self::workgroups_needed(global_size.2, wg.2, "z"),
        );
        Self::from_geometry(wg, count)
    }

    /// Set shared memory size in bytes.
    pub fn with_shared_memory(mut self, size: usize) -> Self {
        self.shared_memory_size = Some(size);
        self
    }

    /// Set execution stream.
    pub fn with_stream(mut self, stream_id: usize) -> Self {
        self.stream_id = Some(stream_id);
        self
    }

    /// Total number of threads launched (workgroup size × workgroup count over
    /// all three axes).
    ///
    /// Saturates at `u64::MAX` instead of overflowing. The product of six
    /// `u32` factors can exceed `u64`, which would panic in debug builds and
    /// silently wrap in release builds.
    pub fn total_threads(&self) -> u64 {
        let (wx, wy, wz) = self.workgroup_size;
        let (cx, cy, cz) = self.workgroup_count;
        [wx, wy, wz, cx, cy, cz]
            .iter()
            .fold(1u64, |acc, &dim| acc.saturating_mul(u64::from(dim)))
    }
}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType};

    /// Build a CPU device with default info for tests that need a `Device`.
    fn create_test_device() -> Device {
        let info = DeviceInfo::default();
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), info)
    }

    #[test]
    fn test_kernel_descriptor_creation() {
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };

        let desc = KernelDescriptor::new("test_kernel".to_string(), source);

        assert_eq!(desc.name, "test_kernel");
        assert!(desc.compile_options.is_empty());
        assert!(desc.parameters.is_empty());
        assert_eq!(desc.workgroup_size_hint, None);
        assert!(desc.cache);
    }

    #[test]
    fn test_kernel_descriptor_builder() {
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Cuda,
        };

        let param = KernelParameter::buffer("input".to_string(), DType::F32, true);

        let desc = KernelDescriptor::new("complex_kernel".to_string(), source)
            .with_compile_option("-O3".to_string())
            .with_compile_option("--fast-math".to_string())
            .with_parameter(param)
            .with_workgroup_size_hint((256, 1, 1))
            .without_cache();

        assert_eq!(desc.name, "complex_kernel");
        assert_eq!(desc.compile_options.len(), 2);
        assert!(desc.compile_options.contains(&"-O3".to_string()));
        assert!(desc.compile_options.contains(&"--fast-math".to_string()));
        assert_eq!(desc.parameters.len(), 1);
        assert_eq!(desc.workgroup_size_hint, Some((256, 1, 1)));
        assert!(!desc.cache);
    }

    #[test]
    fn test_kernel_source_variants() {
        let source1 = KernelSource::Source {
            code: "vertex main() {}".to_string(),
            language: KernelLanguage::Metal,
        };

        let source2 = KernelSource::Bytecode {
            data: vec![0x12, 0x34, 0x56, 0x78],
            format: BytecodeFormat::SpirV,
        };

        let source3 = KernelSource::SpirV {
            data: vec![0x07230203, 0x00010000],
        };

        let source4 = KernelSource::Binary {
            data: vec![0xCA, 0xFE, 0xBA, 0xBE],
            platform: "cuda".to_string(),
        };

        // Test that all variants can be created
        match source1 {
            KernelSource::Source { language, .. } => assert_eq!(language, KernelLanguage::Metal),
            _ => panic!("Wrong variant"),
        }

        match source2 {
            KernelSource::Bytecode { format, .. } => assert_eq!(format, BytecodeFormat::SpirV),
            _ => panic!("Wrong variant"),
        }

        match source3 {
            KernelSource::SpirV { .. } => {}
            _ => panic!("Wrong variant"),
        }

        match source4 {
            KernelSource::Binary { platform, .. } => assert_eq!(platform, "cuda"),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_kernel_language_variants() {
        let languages = [
            KernelLanguage::Wgsl,
            KernelLanguage::Hlsl,
            KernelLanguage::Glsl,
            KernelLanguage::Metal,
            KernelLanguage::Cuda,
            KernelLanguage::OpenCl,
            KernelLanguage::Custom("MyLang".to_string()),
        ];

        // Ensure all languages are distinct
        for (i, lang1) in languages.iter().enumerate() {
            for (j, lang2) in languages.iter().enumerate() {
                if i != j {
                    assert_ne!(lang1, lang2);
                }
            }
        }
    }

    #[test]
    fn test_bytecode_format_variants() {
        let formats = [
            BytecodeFormat::SpirV,
            BytecodeFormat::Dxil,
            BytecodeFormat::MetalAir,
            BytecodeFormat::Ptx,
            BytecodeFormat::Custom("MyFormat".to_string()),
        ];

        // Ensure all formats are distinct
        for (i, format1) in formats.iter().enumerate() {
            for (j, format2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(format1, format2);
                }
            }
        }
    }

    #[test]
    fn test_kernel_parameter_creation() {
        let buffer_param = KernelParameter::buffer("data".to_string(), DType::F32, false);
        assert_eq!(buffer_param.name, "data");
        assert!(!buffer_param.readonly);
        assert_eq!(buffer_param.binding, None);
        match buffer_param.param_type {
            KernelParameterType::Buffer { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let uniform_param = KernelParameter::uniform("scale".to_string(), DType::F32);
        assert_eq!(uniform_param.name, "scale");
        assert!(uniform_param.readonly);
        match uniform_param.param_type {
            KernelParameterType::Uniform { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let bound_param = buffer_param.with_binding(0);
        assert_eq!(bound_param.binding, Some(0));
    }

    #[test]
    fn test_kernel_parameter_types() {
        let types = [
            KernelParameterType::Buffer { dtype: DType::I32 },
            KernelParameterType::Uniform { dtype: DType::F64 },
            KernelParameterType::Texture {
                dimensions: 2,
                dtype: DType::F32,
            },
            KernelParameterType::Sampler,
            KernelParameterType::Scalar { dtype: DType::U8 },
        ];

        // Every pair of variants must be distinct, not just neighbours.
        for (i, a) in types.iter().enumerate() {
            for (j, b) in types.iter().enumerate() {
                if i != j {
                    assert_ne!(std::mem::discriminant(a), std::mem::discriminant(b));
                }
            }
        }
    }

    #[test]
    fn test_kernel_handle_cpu() {
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };

        match handle {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    fn test_kernel_handle_clone_cpu() {
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };

        match handle.clone() {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    #[should_panic(expected = "Cannot clone Generic kernel handles")]
    fn test_kernel_handle_clone_generic_panics() {
        let handle = KernelHandle::Generic {
            handle: Box::new(42u32),
        };
        let _ = handle.clone();
    }

    #[test]
    fn test_kernel_metadata_default() {
        let metadata = KernelMetadata::default();

        assert_eq!(metadata.compile_time_ms, 0.0);
        assert_eq!(metadata.binary_size, 0);
        assert_eq!(metadata.registers_per_thread, None);
        assert_eq!(metadata.shared_memory_usage, None);
        assert_eq!(metadata.max_workgroup_size, None);
        assert_eq!(metadata.compiler_version, "Unknown");
        assert!(metadata.warnings.is_empty());
        assert!(metadata.performance_hints.is_empty());
    }

    #[test]
    fn test_kernel_creation() {
        let device = create_test_device();
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };
        let desc = KernelDescriptor::new("test".to_string(), source);
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };
        let metadata = KernelMetadata::default();

        let kernel = Kernel::new(
            1,
            device.clone(),
            "test_kernel".to_string(),
            desc,
            handle,
            metadata,
        );

        assert_eq!(kernel.id(), 1);
        assert_eq!(kernel.name(), "test_kernel");
        assert_eq!(kernel.device().id(), device.id());
        assert_eq!(kernel.descriptor.name, "test");
        assert_eq!(kernel.metadata().compiler_version, "Unknown");
        assert!(matches!(
            kernel.handle(),
            KernelHandle::Cpu { function } if function.is_null()
        ));
    }

    #[test]
    fn test_kernel_launch_config_linear() {
        let config = KernelLaunchConfig::linear(1000, Some(64));

        assert_eq!(config.workgroup_size, (64, 1, 1));
        assert_eq!(config.workgroup_count, (16, 1, 1)); // ceil(1000/64) = 16
        assert_eq!(config.shared_memory_size, None);
        assert_eq!(config.stream_id, None);
        assert_eq!(config.total_threads(), 64 * 16); // 1024 total threads

        let config_default = KernelLaunchConfig::linear(1000, None);
        assert_eq!(config_default.workgroup_size, (256, 1, 1));
        assert_eq!(config_default.workgroup_count, (4, 1, 1)); // ceil(1000/256) = 4
    }

    #[test]
    fn test_kernel_launch_config_2d() {
        let config = KernelLaunchConfig::grid_2d((100, 50), Some((10, 5)));

        assert_eq!(config.workgroup_size, (10, 5, 1));
        assert_eq!(config.workgroup_count, (10, 10, 1)); // ceil(100/10), ceil(50/5)
        assert_eq!(config.total_threads(), 10 * 5 * 10 * 10); // 5000 total threads

        let config_default = KernelLaunchConfig::grid_2d((100, 50), None);
        assert_eq!(config_default.workgroup_size, (16, 16, 1));
        assert_eq!(config_default.workgroup_count, (7, 4, 1)); // ceil(100/16), ceil(50/16)
    }

    #[test]
    fn test_kernel_launch_config_3d() {
        let config = KernelLaunchConfig::grid_3d((64, 32, 16), Some((8, 4, 2)));

        assert_eq!(config.workgroup_size, (8, 4, 2));
        assert_eq!(config.workgroup_count, (8, 8, 8)); // ceil(64/8), ceil(32/4), ceil(16/2)
        assert_eq!(config.total_threads(), 8 * 4 * 2 * 8 * 8 * 8); // 32768 total threads

        let config_default = KernelLaunchConfig::grid_3d((64, 32, 16), None);
        assert_eq!(config_default.workgroup_size, (8, 8, 8));
        assert_eq!(config_default.workgroup_count, (8, 4, 2)); // ceil(64/8), ceil(32/8), ceil(16/8)
    }

    #[test]
    fn test_kernel_launch_config_builder() {
        let config = KernelLaunchConfig::linear(1000, Some(128))
            .with_shared_memory(4096)
            .with_stream(1);

        assert_eq!(config.shared_memory_size, Some(4096));
        assert_eq!(config.stream_id, Some(1));
    }
}