Skip to main content

torsh_backend/cpu/
kernel.rs

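//! CPU kernel implementation.
//!
//! [`CpuKernel`] maps kernel names to built-in `f32` kernels (element-wise
//! `add`/`mul`/`sub`/`div`, `relu`/`sigmoid`/`tanh` activations, `matmul`,
//! `dot`, and `sum`/`mean` reductions) implemented on top of `optimized_kernels`.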
2
3use crate::cpu::optimized_kernels;
4use crate::error::BackendResult;
5use crate::kernel::{KernelHandle, KernelMetadata};
6use crate::{Buffer, Device, Kernel, KernelDescriptor, KernelLaunchConfig};
7use torsh_core::error::{Result, TorshError};
8
9#[cfg(not(feature = "std"))]
10use alloc::{boxed::Box, string::String, vec::Vec};
11
12/// CPU kernel function signature
13pub type CpuKernelFn = fn(&[&Buffer], &[u8], &KernelLaunchConfig) -> Result<()>;
14
/// CPU kernel implementation.
///
/// Holds the kernel's name, copied from the `KernelDescriptor` it was created
/// from (see [`CpuKernel::new`]).
#[derive(Debug, Clone)]
pub struct CpuKernel {
    /// Kernel name taken verbatim from the descriptor.
    name: String,
}
19
20impl CpuKernel {
21    /// Create a new CPU kernel from a descriptor
22    pub fn new(descriptor: &KernelDescriptor) -> BackendResult<Self> {
23        Ok(Self {
24            name: descriptor.name.clone(),
25        })
26    }
27
28    /// Create a CPU kernel and return an abstract Kernel
29    pub fn new_kernel(device: Device, descriptor: &KernelDescriptor) -> BackendResult<Kernel> {
30        let _cpu_kernel = Self::new(descriptor)?;
31
32        let handle = KernelHandle::Generic {
33            handle: Box::new("CPU kernel placeholder".to_string()),
34        };
35
36        let metadata = KernelMetadata {
37            compile_time_ms: 1.0,
38            binary_size: 0,
39            registers_per_thread: None,
40            shared_memory_usage: None,
41            max_workgroup_size: Some((u32::MAX, 1, 1)),
42            compiler_version: "CPU Backend".to_string(),
43            warnings: Vec::new(),
44            performance_hints: vec!["Use SIMD for better performance".to_string()],
45        };
46
47        let kernel = Kernel::new(
48            0,
49            device,
50            descriptor.name.clone(),
51            descriptor.clone(),
52            handle,
53            metadata,
54        );
55
56        Ok(kernel)
57    }
58
59    /// Get the kernel name
60    pub fn name(&self) -> &str {
61        &self.name
62    }
63
64    /// Helper function to get CPU buffer as f32 slice
65    fn get_cpu_buffer_f32(buffer: &Buffer) -> Result<&[f32]> {
66        match &buffer.handle {
67            crate::buffer::BufferHandle::Cpu { ptr, size } => unsafe {
68                Ok(std::slice::from_raw_parts(*ptr as *const f32, size / 4))
69            },
70            _ => Err(TorshError::InvalidArgument(
71                "Buffer is not CPU buffer".to_string(),
72            )),
73        }
74    }
75
76    /// Helper function to get CPU buffer as mutable f32 slice
77    fn get_cpu_buffer_f32_mut(buffer: &Buffer) -> Result<&mut [f32]> {
78        match &buffer.handle {
79            crate::buffer::BufferHandle::Cpu { ptr, size } => unsafe {
80                Ok(std::slice::from_raw_parts_mut(*ptr as *mut f32, size / 4))
81            },
82            _ => Err(TorshError::InvalidArgument(
83                "Buffer is not CPU buffer".to_string(),
84            )),
85        }
86    }
87
88    /// Execute the kernel
89    pub async fn execute(
90        &self,
91        _buffers: &[&Buffer],
92        _uniform_data: &[u8],
93        _launch_config: &KernelLaunchConfig,
94    ) -> BackendResult<()> {
95        Err(TorshError::BackendError(
96            "CPU kernel execution not yet implemented".to_string(),
97        ))
98    }
99
100    /// Get kernel function by name
101    pub fn get_kernel_fn(descriptor: &KernelDescriptor) -> Result<CpuKernelFn> {
102        let kernel_fn: CpuKernelFn = match descriptor.name.as_str() {
103            "add" => Self::kernel_add,
104            "mul" => Self::kernel_mul,
105            "sub" => Self::kernel_sub,
106            "div" => Self::kernel_div,
107            "relu" => Self::kernel_relu,
108            "sigmoid" => Self::kernel_sigmoid,
109            "tanh" => Self::kernel_tanh,
110            "matmul" => Self::kernel_matmul,
111            "dot" => Self::kernel_dot,
112            "sum" => Self::kernel_sum,
113            "mean" => Self::kernel_mean,
114            _ => {
115                return Err(TorshError::InvalidArgument(format!(
116                    "Unsupported kernel: {}",
117                    descriptor.name
118                )))
119            }
120        };
121
122        Ok(kernel_fn)
123    }
124
125    // Built-in kernel implementations
126
127    /// Element-wise addition kernel
128    fn kernel_add(
129        buffers: &[&Buffer],
130        _uniform_data: &[u8],
131        _launch_config: &KernelLaunchConfig,
132    ) -> Result<()> {
133        if buffers.len() != 3 {
134            return Err(TorshError::InvalidArgument(
135                "Add kernel requires 3 buffers".to_string(),
136            ));
137        }
138
139        let a = Self::get_cpu_buffer_f32(buffers[0])?;
140        let b = Self::get_cpu_buffer_f32(buffers[1])?;
141        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
142
143        if a.len() != b.len() || a.len() != result.len() {
144            return Err(TorshError::InvalidArgument(
145                "Buffer size mismatch".to_string(),
146            ));
147        }
148
149        // Use optimized parallel addition
150        optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x + y);
151
152        Ok(())
153    }
154
155    /// Element-wise multiplication kernel
156    fn kernel_mul(
157        buffers: &[&Buffer],
158        _uniform_data: &[u8],
159        _launch_config: &KernelLaunchConfig,
160    ) -> Result<()> {
161        if buffers.len() != 3 {
162            return Err(TorshError::InvalidArgument(
163                "Mul kernel requires 3 buffers".to_string(),
164            ));
165        }
166
167        let a = Self::get_cpu_buffer_f32(buffers[0])?;
168        let b = Self::get_cpu_buffer_f32(buffers[1])?;
169        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
170
171        if a.len() != b.len() || a.len() != result.len() {
172            return Err(TorshError::InvalidArgument(
173                "Buffer size mismatch".to_string(),
174            ));
175        }
176
177        // Use optimized parallel multiplication
178        optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x * y);
179
180        Ok(())
181    }
182
183    /// Element-wise subtraction kernel
184    fn kernel_sub(
185        buffers: &[&Buffer],
186        _uniform_data: &[u8],
187        _launch_config: &KernelLaunchConfig,
188    ) -> Result<()> {
189        if buffers.len() != 3 {
190            return Err(TorshError::InvalidArgument(
191                "Sub kernel requires 3 buffers".to_string(),
192            ));
193        }
194
195        let a = Self::get_cpu_buffer_f32(buffers[0])?;
196        let b = Self::get_cpu_buffer_f32(buffers[1])?;
197        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
198
199        if a.len() != b.len() || a.len() != result.len() {
200            return Err(TorshError::InvalidArgument(
201                "Buffer size mismatch".to_string(),
202            ));
203        }
204
205        // Use optimized parallel subtraction
206        optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x - y);
207
208        Ok(())
209    }
210
211    /// Element-wise division kernel
212    fn kernel_div(
213        buffers: &[&Buffer],
214        _uniform_data: &[u8],
215        _launch_config: &KernelLaunchConfig,
216    ) -> Result<()> {
217        if buffers.len() != 3 {
218            return Err(TorshError::InvalidArgument(
219                "Div kernel requires 3 buffers".to_string(),
220            ));
221        }
222
223        let a = Self::get_cpu_buffer_f32(buffers[0])?;
224        let b = Self::get_cpu_buffer_f32(buffers[1])?;
225        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
226
227        if a.len() != b.len() || a.len() != result.len() {
228            return Err(TorshError::InvalidArgument(
229                "Buffer size mismatch".to_string(),
230            ));
231        }
232
233        // Use optimized parallel division
234        optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x / y);
235
236        Ok(())
237    }
238
239    /// ReLU activation kernel
240    fn kernel_relu(
241        buffers: &[&Buffer],
242        _uniform_data: &[u8],
243        _launch_config: &KernelLaunchConfig,
244    ) -> Result<()> {
245        if buffers.len() != 2 {
246            return Err(TorshError::InvalidArgument(
247                "ReLU kernel requires 2 buffers".to_string(),
248            ));
249        }
250
251        let input = Self::get_cpu_buffer_f32(buffers[0])?;
252        let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
253
254        if input.len() != output.len() {
255            return Err(TorshError::InvalidArgument(
256                "Buffer size mismatch".to_string(),
257            ));
258        }
259
260        // Use optimized parallel ReLU
261        optimized_kernels::parallel_ops::parallel_unary(input, output, |x| x.max(0.0));
262
263        Ok(())
264    }
265
266    /// Sigmoid activation kernel
267    fn kernel_sigmoid(
268        buffers: &[&Buffer],
269        _uniform_data: &[u8],
270        _launch_config: &KernelLaunchConfig,
271    ) -> Result<()> {
272        if buffers.len() != 2 {
273            return Err(TorshError::InvalidArgument(
274                "Sigmoid kernel requires 2 buffers".to_string(),
275            ));
276        }
277
278        let input = Self::get_cpu_buffer_f32(buffers[0])?;
279        let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
280
281        if input.len() != output.len() {
282            return Err(TorshError::InvalidArgument(
283                "Buffer size mismatch".to_string(),
284            ));
285        }
286
287        // Use optimized parallel Sigmoid
288        optimized_kernels::parallel_ops::parallel_unary(input, output, |x| {
289            1.0 / (1.0 + (-x).exp())
290        });
291
292        Ok(())
293    }
294
295    /// Tanh activation kernel
296    fn kernel_tanh(
297        buffers: &[&Buffer],
298        _uniform_data: &[u8],
299        _launch_config: &KernelLaunchConfig,
300    ) -> Result<()> {
301        if buffers.len() != 2 {
302            return Err(TorshError::InvalidArgument(
303                "Tanh kernel requires 2 buffers".to_string(),
304            ));
305        }
306
307        let input = Self::get_cpu_buffer_f32(buffers[0])?;
308        let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
309
310        if input.len() != output.len() {
311            return Err(TorshError::InvalidArgument(
312                "Buffer size mismatch".to_string(),
313            ));
314        }
315
316        // Use optimized parallel Tanh
317        optimized_kernels::parallel_ops::parallel_unary(input, output, |x| x.tanh());
318
319        Ok(())
320    }
321
322    /// Matrix multiplication kernel using BLAS
323    fn kernel_matmul(
324        buffers: &[&Buffer],
325        uniform_data: &[u8],
326        _launch_config: &KernelLaunchConfig,
327    ) -> Result<()> {
328        if buffers.len() != 3 {
329            return Err(TorshError::InvalidArgument(
330                "Matmul kernel requires 3 buffers".to_string(),
331            ));
332        }
333
334        let a = Self::get_cpu_buffer_f32(buffers[0])?;
335        let b = Self::get_cpu_buffer_f32(buffers[1])?;
336        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
337
338        // Parse matrix dimensions from uniform data
339        // Expected format: [m: u32, n: u32, k: u32, transpose_a: u8, transpose_b: u8]
340        if uniform_data.len() < 14 {
341            return Err(TorshError::InvalidArgument(
342                "Insufficient uniform data for matmul (need m, n, k, transpose_a, transpose_b)"
343                    .to_string(),
344            ));
345        }
346
347        let m = u32::from_le_bytes([
348            uniform_data[0],
349            uniform_data[1],
350            uniform_data[2],
351            uniform_data[3],
352        ]) as usize;
353        let n = u32::from_le_bytes([
354            uniform_data[4],
355            uniform_data[5],
356            uniform_data[6],
357            uniform_data[7],
358        ]) as usize;
359        let k = u32::from_le_bytes([
360            uniform_data[8],
361            uniform_data[9],
362            uniform_data[10],
363            uniform_data[11],
364        ]) as usize;
365        let transpose_a = uniform_data[12] != 0;
366        let transpose_b = uniform_data[13] != 0;
367
368        // Use optimized matrix multiplication
369        optimized_kernels::optimized_matmul(a, b, result, m, n, k, transpose_a, transpose_b)
370    }
371
372    /// Dot product kernel using BLAS
373    fn kernel_dot(
374        buffers: &[&Buffer],
375        _uniform_data: &[u8],
376        _launch_config: &KernelLaunchConfig,
377    ) -> Result<()> {
378        if buffers.len() != 3 {
379            return Err(TorshError::InvalidArgument(
380                "Dot kernel requires 3 buffers".to_string(),
381            ));
382        }
383
384        let a = Self::get_cpu_buffer_f32(buffers[0])?;
385        let b = Self::get_cpu_buffer_f32(buffers[1])?;
386        let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
387
388        if result.len() != 1 {
389            return Err(TorshError::InvalidArgument(
390                "Output buffer should have size 1 for dot product".to_string(),
391            ));
392        }
393
394        // Use optimized dot product
395        result[0] = optimized_kernels::optimized_dot(a, b)?;
396
397        Ok(())
398    }
399
400    /// Sum reduction kernel
401    fn kernel_sum(
402        buffers: &[&Buffer],
403        _uniform_data: &[u8],
404        _launch_config: &KernelLaunchConfig,
405    ) -> Result<()> {
406        if buffers.len() != 2 {
407            return Err(TorshError::InvalidArgument(
408                "Sum kernel requires 2 buffers".to_string(),
409            ));
410        }
411
412        let input = Self::get_cpu_buffer_f32(buffers[0])?;
413        let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
414
415        if output.len() != 1 {
416            return Err(TorshError::InvalidArgument(
417                "Output buffer should have size 1 for sum reduction".to_string(),
418            ));
419        }
420
421        // Use optimized parallel sum
422        output[0] = optimized_kernels::parallel_ops::parallel_sum(input);
423
424        Ok(())
425    }
426
427    /// Mean reduction kernel
428    fn kernel_mean(
429        buffers: &[&Buffer],
430        _uniform_data: &[u8],
431        _launch_config: &KernelLaunchConfig,
432    ) -> Result<()> {
433        if buffers.len() != 2 {
434            return Err(TorshError::InvalidArgument(
435                "Mean kernel requires 2 buffers".to_string(),
436            ));
437        }
438
439        let input = Self::get_cpu_buffer_f32(buffers[0])?;
440        let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
441
442        if output.len() != 1 {
443            return Err(TorshError::InvalidArgument(
444                "Output buffer should have size 1 for mean reduction".to_string(),
445            ));
446        }
447
448        // Use optimized parallel mean
449        output[0] = optimized_kernels::parallel_ops::parallel_mean(input);
450
451        Ok(())
452    }
453}
454
455// Extension trait for Kernel to work with CPU kernels
456pub trait KernelCpuExt {
457    fn is_cpu(&self) -> bool;
458}
459
460impl KernelCpuExt for Kernel {
461    fn is_cpu(&self) -> bool {
462        matches!(self.handle, KernelHandle::Generic { .. })
463    }
464}
465
/// CPU kernel executor for managing kernel execution.
///
/// The executor currently holds no state, so every instance is interchangeable.
#[derive(Debug, Clone, Default)]
pub struct CpuKernelExecutor;

impl CpuKernelExecutor {
    /// Create a new CPU kernel executor (equivalent to `CpuKernelExecutor::default()`).
    pub const fn new() -> Self {
        Self
    }
}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{KernelLanguage, KernelSource};

    /// Build a descriptor for a CPU kernel named `name` with placeholder source.
    fn cpu_descriptor(name: &str) -> KernelDescriptor {
        KernelDescriptor::new(
            name.to_string(),
            KernelSource::Source {
                code: format!("// CPU {} kernel", name),
                language: KernelLanguage::Custom("CPU".to_string()),
            },
        )
    }

    #[test]
    fn test_cpu_kernel_creation() {
        let kernel = CpuKernel::new(&cpu_descriptor("add")).unwrap();
        assert_eq!(kernel.name(), "add");
    }

    #[test]
    fn test_get_kernel_fn_supports_builtin_names() {
        let names = [
            "add", "mul", "sub", "div", "relu", "sigmoid", "tanh", "matmul", "dot", "sum", "mean",
        ];
        for name in names.iter() {
            assert!(
                CpuKernel::get_kernel_fn(&cpu_descriptor(name)).is_ok(),
                "{} should be a supported CPU kernel",
                name
            );
        }
    }

    #[test]
    fn test_get_kernel_fn_rejects_unknown_name() {
        assert!(CpuKernel::get_kernel_fn(&cpu_descriptor("conv3d")).is_err());
    }
}
502}