1use std::sync::Arc;
12
13use svod_device::Result;
14use svod_device::device::{Compiler, Device, Program, ProgramSpec, Renderer, RuntimeFactory};
15use svod_device::registry::DeviceRegistry;
16use svod_dtype::DeviceSpec;
17use svod_ir::UOp;
18
19use crate::LlvmKernel;
20use crate::clang::ClangKernel;
21use crate::dispatch::KernelCif;
22
/// Code-generation backend used to build CPU kernels.
///
/// [`CpuBackend::Clang`] is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CpuBackend {
    /// C source backend (the default).
    #[default]
    Clang,
    /// LLVM IR backend.
    Llvm,
    /// MLIR backend; only compiled in when the `mlir` feature is enabled.
    #[cfg(feature = "mlir")]
    Mlir,
}

impl CpuBackend {
    /// Environment variable consulted by [`CpuBackend::from_env`].
    const ENV_VAR: &'static str = "SVOD_CPU_BACKEND";

    /// Selects the backend named by the `SVOD_CPU_BACKEND` environment variable.
    ///
    /// The value is matched case-insensitively and surrounding whitespace is
    /// ignored, so `llvm`, `LLVM`, `Llvm` and `" llvm "` all select
    /// [`CpuBackend::Llvm`]. If the variable is unset, is not valid Unicode, or
    /// names an unknown backend (including `mlir` when the `mlir` feature is
    /// disabled), the default backend is returned.
    pub fn from_env() -> Self {
        std::env::var(Self::ENV_VAR)
            .ok()
            .and_then(|raw| Self::from_name(&raw))
            .unwrap_or_default()
    }

    /// Parses a backend name, returning `None` for unrecognised names.
    fn from_name(name: &str) -> Option<Self> {
        match name.trim().to_ascii_lowercase().as_str() {
            "clang" => Some(CpuBackend::Clang),
            "llvm" => Some(CpuBackend::Llvm),
            #[cfg(feature = "mlir")]
            "mlir" => Some(CpuBackend::Mlir),
            _ => None,
        }
    }
}
51
52unsafe fn execute_parallel(
68 cif: &KernelCif,
69 fn_ptr: *const (),
70 buffers: &[*mut u8],
71 vals: &[i64],
72 var_names: &[String],
73 core_count: usize,
74) -> Result<()> {
75 use rayon::prelude::*;
76
77 let core_id_idx = var_names.iter().position(|n| n == "core_id").ok_or_else(|| svod_device::Error::Runtime {
78 message: "parallel CPU launch requires core_id runtime variable".to_string(),
79 })?;
80 let fn_ptr_usize = fn_ptr as usize;
81
82 let buf_ptr = buffers.as_ptr() as usize;
86 let buf_len = buffers.len();
87
88 if rayon::current_thread_index().is_some() {
91 for core_id in 0..core_count {
92 let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
93 unsafe {
94 cif.dispatch(fn_ptr_usize as *const (), bufs, vals, Some((core_id_idx, core_id)));
95 }
96 }
97 return Ok(());
98 }
99
100 (0..core_count).into_par_iter().for_each(|core_id| {
101 let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
102 unsafe {
103 cif.dispatch(fn_ptr_usize as *const (), bufs, vals, Some((core_id_idx, core_id)));
104 }
105 });
106
107 Ok(())
108}
109
110unsafe fn execute_kernel(
116 cif: &KernelCif,
117 fn_ptr: *const (),
118 buffers: &[*mut u8],
119 vals: &[i64],
120 var_names: &[String],
121 global_size: Option<[usize; 3]>,
122) -> Result<()> {
123 let core_count = global_size.map(|[tc, _, _]| tc).filter(|&tc| tc > 1);
124 if let Some(count) = core_count {
125 unsafe { execute_parallel(cif, fn_ptr, buffers, vals, var_names, count) }
126 } else {
127 unsafe { cif.dispatch(fn_ptr, buffers, vals, None) };
128 Ok(())
129 }
130}
131
132struct ClangProgram {
138 kernel: ClangKernel,
139}
140
141impl Program for ClangProgram {
142 unsafe fn execute(
143 &self,
144 buffers: &[*mut u8],
145 vals: &[i64],
146 global_size: Option<[usize; 3]>,
147 _local_size: Option<[usize; 3]>,
148 ) -> Result<()> {
149 unsafe {
150 execute_kernel(self.kernel.cif(), self.kernel.fn_ptr(), buffers, vals, self.kernel.var_names(), global_size)
151 }
152 }
153
154 fn name(&self) -> &str {
155 self.kernel.name()
156 }
157}
158
159struct ClangRendererWrapper {
161 device: DeviceSpec,
162}
163
164impl Renderer for ClangRendererWrapper {
165 fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
166 let rendered = svod_codegen::c::render(ast, name.or(Some("kernel")))
167 .map_err(|e| svod_device::Error::Runtime { message: format!("C rendering failed: {}", e) })?;
168
169 let mut spec = ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
170
171 spec.set_var_names(rendered.var_names.clone());
172 spec.apply_derived_metadata_from_ast();
173 if spec.buf_count == 0 {
174 spec.buf_count = rendered.buffer_args.len();
175 }
176
177 Ok(spec)
178 }
179
180 fn device(&self) -> &DeviceSpec {
181 &self.device
182 }
183}
184
185struct ClangCompiler;
187
188impl Compiler for ClangCompiler {
189 fn compile(&self, spec: &ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
190 let mut compiled = svod_device::device::CompiledSpec::from_source(
191 spec.name.clone(),
192 spec.src.clone(),
193 spec.ast.clone(),
194 spec.buf_count,
195 );
196 compiled.var_names = spec.var_names.clone();
197 compiled.global_size = spec.global_size.clone();
198 compiled.local_size = spec.local_size.clone();
199 Ok(compiled)
200 }
201
202 fn cache_key(&self) -> &'static str {
203 "clang"
204 }
205}
206
207fn create_clang_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
209 let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
210 message: "Clang backend requires source code in CompiledSpec".to_string(),
211 })?;
212
213 let kernel = ClangKernel::compile(src, &spec.name, spec.var_names.clone(), spec.buf_count)
214 .map_err(|e| svod_device::Error::Runtime { message: format!("Clang compilation failed: {}", e) })?;
215
216 Ok(Box::new(ClangProgram { kernel }))
217}
218
219struct LlvmProgram {
225 kernel: LlvmKernel,
226}
227
228impl Program for LlvmProgram {
229 unsafe fn execute(
230 &self,
231 buffers: &[*mut u8],
232 vals: &[i64],
233 global_size: Option<[usize; 3]>,
234 _local_size: Option<[usize; 3]>,
235 ) -> Result<()> {
236 unsafe {
237 execute_kernel(self.kernel.cif(), self.kernel.fn_ptr(), buffers, vals, self.kernel.var_names(), global_size)
238 }
239 }
240
241 fn name(&self) -> &str {
242 self.kernel.name()
243 }
244}
245
246struct LlvmCompiler;
248
249impl Compiler for LlvmCompiler {
250 fn compile(&self, spec: &svod_device::device::ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
251 let mut compiled = svod_device::device::CompiledSpec::from_source(
252 spec.name.clone(),
253 spec.src.clone(),
254 spec.ast.clone(),
255 spec.buf_count,
256 );
257 compiled.var_names = spec.var_names.clone();
258 compiled.global_size = spec.global_size.clone();
259 compiled.local_size = spec.local_size.clone();
260 Ok(compiled)
261 }
262
263 fn cache_key(&self) -> &'static str {
264 "llvm-jit"
265 }
266}
267
268struct LlvmRendererWrapper {
270 device: DeviceSpec,
271}
272
273impl Renderer for LlvmRendererWrapper {
274 fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
275 let rendered = svod_codegen::llvm::text::render(ast, name.or(Some("kernel")))
276 .map_err(|e| svod_device::Error::Runtime { message: format!("LLVM rendering failed: {}", e) })?;
277
278 let mut spec = ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
279
280 spec.set_var_names(rendered.var_names.clone());
281 spec.apply_derived_metadata_from_ast();
282 if spec.buf_count == 0 {
283 spec.buf_count = rendered.buffer_args.len();
284 }
285
286 Ok(spec)
287 }
288
289 fn device(&self) -> &DeviceSpec {
290 &self.device
291 }
292}
293
294fn create_llvm_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
296 let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
297 message: "LLVM JIT requires source code in CompiledSpec".to_string(),
298 })?;
299
300 let kernel = crate::LlvmKernel::compile_ir(src, &spec.name, &spec.name, spec.var_names.clone(), spec.buf_count)
301 .map_err(|e| svod_device::Error::Runtime { message: format!("LLVM JIT compilation failed: {}", e) })?;
302
303 Ok(Box::new(LlvmProgram { kernel }))
304}
305
306#[cfg(feature = "mlir")]
311mod mlir_backend {
312 use std::ffi::c_void;
313
314 use super::*;
315
316 type MlirKernelFn = unsafe extern "C" fn(*const *mut u8, *const i64);
317
318 unsafe fn dispatch_mlir_fn(fn_ptr: *const c_void, buffers: &[*mut u8], vals: &[i64]) {
319 let kernel: MlirKernelFn = unsafe { std::mem::transmute(fn_ptr) };
320 let buffer_usizes: Vec<usize> = buffers.iter().map(|&ptr| ptr as usize).collect();
321 let bufs_ptr = buffer_usizes.as_ptr() as *const *mut u8;
322 unsafe {
323 kernel(bufs_ptr, vals.as_ptr());
324 }
325 }
326
327 unsafe fn execute_mlir_parallel(
328 fn_ptr: *const c_void,
329 buffers: &[*mut u8],
330 vals: &[i64],
331 var_names: &[String],
332 core_count: usize,
333 ) -> Result<()> {
334 use rayon::prelude::*;
335
336 let core_id_idx = var_names.iter().position(|n| n == "core_id").ok_or_else(|| svod_device::Error::Runtime {
337 message: "parallel MLIR CPU launch requires core_id runtime variable".to_string(),
338 })?;
339 let fn_ptr_usize = fn_ptr as usize;
340
341 let buf_ptr = buffers.as_ptr() as usize;
343 let buf_len = buffers.len();
344 let vals = vals.to_vec();
345
346 if rayon::current_thread_index().is_some() {
348 for core_id in 0..core_count {
349 let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
350 let mut thread_vals = vals.clone();
351 thread_vals[core_id_idx] = core_id as i64;
352 unsafe { dispatch_mlir_fn(fn_ptr_usize as *const c_void, bufs, &thread_vals) };
353 }
354 return Ok(());
355 }
356
357 (0..core_count).into_par_iter().for_each(|core_id| {
358 let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
359 let mut thread_vals = vals.clone();
360 thread_vals[core_id_idx] = core_id as i64;
361 unsafe { dispatch_mlir_fn(fn_ptr_usize as *const c_void, bufs, &thread_vals) };
362 });
363
364 Ok(())
365 }
366
367 pub struct MlirProgram {
369 pub kernel: crate::mlir::MlirKernel,
370 }
371
372 impl Program for MlirProgram {
373 unsafe fn execute(
374 &self,
375 buffers: &[*mut u8],
376 vals: &[i64],
377 global_size: Option<[usize; 3]>,
378 _local_size: Option<[usize; 3]>,
379 ) -> Result<()> {
380 let core_count = global_size.map(|[tc, _, _]| tc).filter(|&tc| tc > 1);
381 let fn_ptr = self.kernel.fn_ptr();
382
383 if let Some(count) = core_count {
384 unsafe { execute_mlir_parallel(fn_ptr, buffers, vals, self.kernel.var_names(), count) }
385 } else {
386 unsafe { dispatch_mlir_fn(fn_ptr, buffers, vals) };
387 Ok(())
388 }
389 }
390
391 fn name(&self) -> &str {
392 self.kernel.name()
393 }
394 }
395
396 pub struct MlirRendererWrapper {
398 pub device: DeviceSpec,
399 }
400
401 impl Renderer for MlirRendererWrapper {
402 fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
403 let rendered = svod_codegen::mlir::render(ast, name.or(Some("kernel")))
404 .map_err(|e| svod_device::Error::Runtime { message: format!("MLIR rendering failed: {}", e) })?;
405
406 let mut spec =
407 ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
408
409 spec.set_var_names(rendered.var_names.clone());
410 spec.apply_derived_metadata_from_ast();
411 if spec.buf_count == 0 {
412 spec.buf_count = rendered.buffer_args.len();
413 }
414
415 Ok(spec)
416 }
417
418 fn device(&self) -> &DeviceSpec {
419 &self.device
420 }
421
422 fn decompositor(&self) -> Option<svod_ir::pattern::TypedPatternMatcher<()>> {
423 use svod_ir::decompositions::ptrcat_decomposition_patterns;
424 Some(ptrcat_decomposition_patterns())
425 }
426 }
427
428 pub struct MlirCompiler;
430
431 impl Compiler for MlirCompiler {
432 fn compile(&self, spec: &svod_device::device::ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
433 let mut compiled = svod_device::device::CompiledSpec::from_source(
434 spec.name.clone(),
435 spec.src.clone(),
436 spec.ast.clone(),
437 spec.buf_count,
438 );
439 compiled.var_names = spec.var_names.clone();
440 compiled.global_size = spec.global_size.clone();
441 compiled.local_size = spec.local_size.clone();
442 Ok(compiled)
443 }
444
445 fn cache_key(&self) -> &'static str {
446 "mlir-exec-engine"
447 }
448 }
449
450 pub fn create_mlir_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
452 let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
453 message: "MLIR backend requires source code (MLIR text) in CompiledSpec".to_string(),
454 })?;
455
456 let kernel = crate::mlir::MlirKernel::compile(src, &spec.name, spec.var_names.clone()).map_err(|e| {
457 svod_device::Error::Runtime { message: format!("MLIR ExecutionEngine compilation failed: {}", e) }
458 })?;
459
460 Ok(Box::new(MlirProgram { kernel }))
461 }
462}
463
464#[cfg(feature = "mlir")]
465use mlir_backend::{MlirCompiler, MlirRendererWrapper, create_mlir_program};
466
467pub fn create_cpu_device(registry: &DeviceRegistry) -> Result<Device> {
477 create_cpu_device_with_backend(registry, CpuBackend::from_env())
478}
479
480pub fn create_cpu_device_with_backend(registry: &DeviceRegistry, backend: CpuBackend) -> Result<Device> {
482 let device_spec = DeviceSpec::Cpu;
483 let allocator = registry.get(&device_spec)?;
484
485 match backend {
486 CpuBackend::Clang => {
487 let renderer = Arc::new(ClangRendererWrapper { device: device_spec.clone() });
488 let compiler = Arc::new(ClangCompiler);
489 let runtime: RuntimeFactory = Arc::new(create_clang_program);
490 Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
491 }
492 CpuBackend::Llvm => {
493 let renderer = Arc::new(LlvmRendererWrapper { device: device_spec.clone() });
494 let compiler = Arc::new(LlvmCompiler);
495 let runtime: RuntimeFactory = Arc::new(create_llvm_program);
496 Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
497 }
498 #[cfg(feature = "mlir")]
499 CpuBackend::Mlir => {
500 let renderer = Arc::new(MlirRendererWrapper { device: device_spec.clone() });
501 let compiler = Arc::new(MlirCompiler);
502 let runtime: RuntimeFactory = Arc::new(create_mlir_program);
503 Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
504 }
505 }
506}