1use crate::cpu::optimized_kernels;
4use crate::error::BackendResult;
5use crate::kernel::{KernelHandle, KernelMetadata};
6use crate::{Buffer, Device, Kernel, KernelDescriptor, KernelLaunchConfig};
7use torsh_core::error::{Result, TorshError};
8
9#[cfg(not(feature = "std"))]
10use alloc::{boxed::Box, string::String, vec::Vec};
11
12pub type CpuKernelFn = fn(&[&Buffer], &[u8], &KernelLaunchConfig) -> Result<()>;
14
/// A CPU-backend kernel, identified solely by the name copied from its descriptor.
#[derive(Debug, Clone)]
pub struct CpuKernel {
    /// Kernel name, copied from `KernelDescriptor::name` at construction.
    name: String,
}
19
20impl CpuKernel {
21 pub fn new(descriptor: &KernelDescriptor) -> BackendResult<Self> {
23 Ok(Self {
24 name: descriptor.name.clone(),
25 })
26 }
27
28 pub fn new_kernel(device: Device, descriptor: &KernelDescriptor) -> BackendResult<Kernel> {
30 let _cpu_kernel = Self::new(descriptor)?;
31
32 let handle = KernelHandle::Generic {
33 handle: Box::new("CPU kernel placeholder".to_string()),
34 };
35
36 let metadata = KernelMetadata {
37 compile_time_ms: 1.0,
38 binary_size: 0,
39 registers_per_thread: None,
40 shared_memory_usage: None,
41 max_workgroup_size: Some((u32::MAX, 1, 1)),
42 compiler_version: "CPU Backend".to_string(),
43 warnings: Vec::new(),
44 performance_hints: vec!["Use SIMD for better performance".to_string()],
45 };
46
47 let kernel = Kernel::new(
48 0,
49 device,
50 descriptor.name.clone(),
51 descriptor.clone(),
52 handle,
53 metadata,
54 );
55
56 Ok(kernel)
57 }
58
59 pub fn name(&self) -> &str {
61 &self.name
62 }
63
64 fn get_cpu_buffer_f32(buffer: &Buffer) -> Result<&[f32]> {
66 match &buffer.handle {
67 crate::buffer::BufferHandle::Cpu { ptr, size } => unsafe {
68 Ok(std::slice::from_raw_parts(*ptr as *const f32, size / 4))
69 },
70 _ => Err(TorshError::InvalidArgument(
71 "Buffer is not CPU buffer".to_string(),
72 )),
73 }
74 }
75
76 fn get_cpu_buffer_f32_mut(buffer: &Buffer) -> Result<&mut [f32]> {
78 match &buffer.handle {
79 crate::buffer::BufferHandle::Cpu { ptr, size } => unsafe {
80 Ok(std::slice::from_raw_parts_mut(*ptr as *mut f32, size / 4))
81 },
82 _ => Err(TorshError::InvalidArgument(
83 "Buffer is not CPU buffer".to_string(),
84 )),
85 }
86 }
87
88 pub async fn execute(
90 &self,
91 _buffers: &[&Buffer],
92 _uniform_data: &[u8],
93 _launch_config: &KernelLaunchConfig,
94 ) -> BackendResult<()> {
95 Err(TorshError::BackendError(
96 "CPU kernel execution not yet implemented".to_string(),
97 ))
98 }
99
100 pub fn get_kernel_fn(descriptor: &KernelDescriptor) -> Result<CpuKernelFn> {
102 let kernel_fn: CpuKernelFn = match descriptor.name.as_str() {
103 "add" => Self::kernel_add,
104 "mul" => Self::kernel_mul,
105 "sub" => Self::kernel_sub,
106 "div" => Self::kernel_div,
107 "relu" => Self::kernel_relu,
108 "sigmoid" => Self::kernel_sigmoid,
109 "tanh" => Self::kernel_tanh,
110 "matmul" => Self::kernel_matmul,
111 "dot" => Self::kernel_dot,
112 "sum" => Self::kernel_sum,
113 "mean" => Self::kernel_mean,
114 _ => {
115 return Err(TorshError::InvalidArgument(format!(
116 "Unsupported kernel: {}",
117 descriptor.name
118 )))
119 }
120 };
121
122 Ok(kernel_fn)
123 }
124
125 fn kernel_add(
129 buffers: &[&Buffer],
130 _uniform_data: &[u8],
131 _launch_config: &KernelLaunchConfig,
132 ) -> Result<()> {
133 if buffers.len() != 3 {
134 return Err(TorshError::InvalidArgument(
135 "Add kernel requires 3 buffers".to_string(),
136 ));
137 }
138
139 let a = Self::get_cpu_buffer_f32(buffers[0])?;
140 let b = Self::get_cpu_buffer_f32(buffers[1])?;
141 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
142
143 if a.len() != b.len() || a.len() != result.len() {
144 return Err(TorshError::InvalidArgument(
145 "Buffer size mismatch".to_string(),
146 ));
147 }
148
149 optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x + y);
151
152 Ok(())
153 }
154
155 fn kernel_mul(
157 buffers: &[&Buffer],
158 _uniform_data: &[u8],
159 _launch_config: &KernelLaunchConfig,
160 ) -> Result<()> {
161 if buffers.len() != 3 {
162 return Err(TorshError::InvalidArgument(
163 "Mul kernel requires 3 buffers".to_string(),
164 ));
165 }
166
167 let a = Self::get_cpu_buffer_f32(buffers[0])?;
168 let b = Self::get_cpu_buffer_f32(buffers[1])?;
169 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
170
171 if a.len() != b.len() || a.len() != result.len() {
172 return Err(TorshError::InvalidArgument(
173 "Buffer size mismatch".to_string(),
174 ));
175 }
176
177 optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x * y);
179
180 Ok(())
181 }
182
183 fn kernel_sub(
185 buffers: &[&Buffer],
186 _uniform_data: &[u8],
187 _launch_config: &KernelLaunchConfig,
188 ) -> Result<()> {
189 if buffers.len() != 3 {
190 return Err(TorshError::InvalidArgument(
191 "Sub kernel requires 3 buffers".to_string(),
192 ));
193 }
194
195 let a = Self::get_cpu_buffer_f32(buffers[0])?;
196 let b = Self::get_cpu_buffer_f32(buffers[1])?;
197 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
198
199 if a.len() != b.len() || a.len() != result.len() {
200 return Err(TorshError::InvalidArgument(
201 "Buffer size mismatch".to_string(),
202 ));
203 }
204
205 optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x - y);
207
208 Ok(())
209 }
210
211 fn kernel_div(
213 buffers: &[&Buffer],
214 _uniform_data: &[u8],
215 _launch_config: &KernelLaunchConfig,
216 ) -> Result<()> {
217 if buffers.len() != 3 {
218 return Err(TorshError::InvalidArgument(
219 "Div kernel requires 3 buffers".to_string(),
220 ));
221 }
222
223 let a = Self::get_cpu_buffer_f32(buffers[0])?;
224 let b = Self::get_cpu_buffer_f32(buffers[1])?;
225 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
226
227 if a.len() != b.len() || a.len() != result.len() {
228 return Err(TorshError::InvalidArgument(
229 "Buffer size mismatch".to_string(),
230 ));
231 }
232
233 optimized_kernels::parallel_ops::parallel_elementwise(a, b, result, |x, y| x / y);
235
236 Ok(())
237 }
238
239 fn kernel_relu(
241 buffers: &[&Buffer],
242 _uniform_data: &[u8],
243 _launch_config: &KernelLaunchConfig,
244 ) -> Result<()> {
245 if buffers.len() != 2 {
246 return Err(TorshError::InvalidArgument(
247 "ReLU kernel requires 2 buffers".to_string(),
248 ));
249 }
250
251 let input = Self::get_cpu_buffer_f32(buffers[0])?;
252 let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
253
254 if input.len() != output.len() {
255 return Err(TorshError::InvalidArgument(
256 "Buffer size mismatch".to_string(),
257 ));
258 }
259
260 optimized_kernels::parallel_ops::parallel_unary(input, output, |x| x.max(0.0));
262
263 Ok(())
264 }
265
266 fn kernel_sigmoid(
268 buffers: &[&Buffer],
269 _uniform_data: &[u8],
270 _launch_config: &KernelLaunchConfig,
271 ) -> Result<()> {
272 if buffers.len() != 2 {
273 return Err(TorshError::InvalidArgument(
274 "Sigmoid kernel requires 2 buffers".to_string(),
275 ));
276 }
277
278 let input = Self::get_cpu_buffer_f32(buffers[0])?;
279 let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
280
281 if input.len() != output.len() {
282 return Err(TorshError::InvalidArgument(
283 "Buffer size mismatch".to_string(),
284 ));
285 }
286
287 optimized_kernels::parallel_ops::parallel_unary(input, output, |x| {
289 1.0 / (1.0 + (-x).exp())
290 });
291
292 Ok(())
293 }
294
295 fn kernel_tanh(
297 buffers: &[&Buffer],
298 _uniform_data: &[u8],
299 _launch_config: &KernelLaunchConfig,
300 ) -> Result<()> {
301 if buffers.len() != 2 {
302 return Err(TorshError::InvalidArgument(
303 "Tanh kernel requires 2 buffers".to_string(),
304 ));
305 }
306
307 let input = Self::get_cpu_buffer_f32(buffers[0])?;
308 let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
309
310 if input.len() != output.len() {
311 return Err(TorshError::InvalidArgument(
312 "Buffer size mismatch".to_string(),
313 ));
314 }
315
316 optimized_kernels::parallel_ops::parallel_unary(input, output, |x| x.tanh());
318
319 Ok(())
320 }
321
322 fn kernel_matmul(
324 buffers: &[&Buffer],
325 uniform_data: &[u8],
326 _launch_config: &KernelLaunchConfig,
327 ) -> Result<()> {
328 if buffers.len() != 3 {
329 return Err(TorshError::InvalidArgument(
330 "Matmul kernel requires 3 buffers".to_string(),
331 ));
332 }
333
334 let a = Self::get_cpu_buffer_f32(buffers[0])?;
335 let b = Self::get_cpu_buffer_f32(buffers[1])?;
336 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
337
338 if uniform_data.len() < 14 {
341 return Err(TorshError::InvalidArgument(
342 "Insufficient uniform data for matmul (need m, n, k, transpose_a, transpose_b)"
343 .to_string(),
344 ));
345 }
346
347 let m = u32::from_le_bytes([
348 uniform_data[0],
349 uniform_data[1],
350 uniform_data[2],
351 uniform_data[3],
352 ]) as usize;
353 let n = u32::from_le_bytes([
354 uniform_data[4],
355 uniform_data[5],
356 uniform_data[6],
357 uniform_data[7],
358 ]) as usize;
359 let k = u32::from_le_bytes([
360 uniform_data[8],
361 uniform_data[9],
362 uniform_data[10],
363 uniform_data[11],
364 ]) as usize;
365 let transpose_a = uniform_data[12] != 0;
366 let transpose_b = uniform_data[13] != 0;
367
368 optimized_kernels::optimized_matmul(a, b, result, m, n, k, transpose_a, transpose_b)
370 }
371
372 fn kernel_dot(
374 buffers: &[&Buffer],
375 _uniform_data: &[u8],
376 _launch_config: &KernelLaunchConfig,
377 ) -> Result<()> {
378 if buffers.len() != 3 {
379 return Err(TorshError::InvalidArgument(
380 "Dot kernel requires 3 buffers".to_string(),
381 ));
382 }
383
384 let a = Self::get_cpu_buffer_f32(buffers[0])?;
385 let b = Self::get_cpu_buffer_f32(buffers[1])?;
386 let result = Self::get_cpu_buffer_f32_mut(buffers[2])?;
387
388 if result.len() != 1 {
389 return Err(TorshError::InvalidArgument(
390 "Output buffer should have size 1 for dot product".to_string(),
391 ));
392 }
393
394 result[0] = optimized_kernels::optimized_dot(a, b)?;
396
397 Ok(())
398 }
399
400 fn kernel_sum(
402 buffers: &[&Buffer],
403 _uniform_data: &[u8],
404 _launch_config: &KernelLaunchConfig,
405 ) -> Result<()> {
406 if buffers.len() != 2 {
407 return Err(TorshError::InvalidArgument(
408 "Sum kernel requires 2 buffers".to_string(),
409 ));
410 }
411
412 let input = Self::get_cpu_buffer_f32(buffers[0])?;
413 let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
414
415 if output.len() != 1 {
416 return Err(TorshError::InvalidArgument(
417 "Output buffer should have size 1 for sum reduction".to_string(),
418 ));
419 }
420
421 output[0] = optimized_kernels::parallel_ops::parallel_sum(input);
423
424 Ok(())
425 }
426
427 fn kernel_mean(
429 buffers: &[&Buffer],
430 _uniform_data: &[u8],
431 _launch_config: &KernelLaunchConfig,
432 ) -> Result<()> {
433 if buffers.len() != 2 {
434 return Err(TorshError::InvalidArgument(
435 "Mean kernel requires 2 buffers".to_string(),
436 ));
437 }
438
439 let input = Self::get_cpu_buffer_f32(buffers[0])?;
440 let output = Self::get_cpu_buffer_f32_mut(buffers[1])?;
441
442 if output.len() != 1 {
443 return Err(TorshError::InvalidArgument(
444 "Output buffer should have size 1 for mean reduction".to_string(),
445 ));
446 }
447
448 output[0] = optimized_kernels::parallel_ops::parallel_mean(input);
450
451 Ok(())
452 }
453}
454
/// Extension trait answering whether a kernel belongs to the CPU backend.
pub trait KernelCpuExt {
    /// Returns `true` when the implementor is considered a CPU kernel.
    fn is_cpu(&self) -> bool;
}
459
460impl KernelCpuExt for Kernel {
461 fn is_cpu(&self) -> bool {
462 matches!(self.handle, KernelHandle::Generic { .. })
463 }
464}
465
/// Executor type for CPU kernels; it currently carries no state and only exposes
/// constructors.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CpuKernelExecutor;

impl CpuKernelExecutor {
    /// Creates an executor; equivalent to `CpuKernelExecutor::default()`.
    pub const fn new() -> Self {
        Self
    }
}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{KernelLanguage, KernelSource};

    /// Builds a descriptor for a CPU kernel called `name` with placeholder source.
    fn cpu_descriptor(name: &str) -> KernelDescriptor {
        KernelDescriptor::new(
            name.to_string(),
            KernelSource::Source {
                code: format!("// CPU {} kernel", name),
                language: KernelLanguage::Custom("CPU".to_string()),
            },
        )
    }

    #[test]
    fn test_cpu_kernel_creation() {
        let kernel = CpuKernel::new(&cpu_descriptor("add")).unwrap();
        assert_eq!(kernel.name(), "add");
    }

    #[test]
    fn test_get_kernel_fn_resolves_supported_names() {
        let supported = [
            "add", "mul", "sub", "div", "relu", "sigmoid", "tanh", "matmul", "dot", "sum",
            "mean",
        ];
        for name in supported.iter() {
            assert!(
                CpuKernel::get_kernel_fn(&cpu_descriptor(name)).is_ok(),
                "kernel {} should be supported",
                name
            );
        }
    }

    #[test]
    fn test_get_kernel_fn_rejects_unknown_name() {
        assert!(CpuKernel::get_kernel_fn(&cpu_descriptor("conv2d")).is_err());
    }
}