Skip to main content

svod_runtime/devices/
cpu.rs

//! CPU device implementation with selectable JIT backends.
//!
//! This module provides a Device instance for CPU execution using one of:
//! - Clang C codegen (default, human-readable, fast debug cycles)
//! - LLVM JIT (maximum optimization, slower compilation)
//! - MLIR (only when built with the `mlir` feature; lowered to LLVM IR, then JIT compiled)
//!
//! The backend can be selected via:
//! - `SVOD_CPU_BACKEND` environment variable ("clang" or "llvm"; "mlir" with the `mlir` feature)
//! - Explicit `create_cpu_device_with_backend()` call
10
11use std::sync::Arc;
12
13use svod_device::Result;
14use svod_device::device::{Compiler, Device, Program, ProgramSpec, Renderer, RuntimeFactory};
15use svod_device::registry::DeviceRegistry;
16use svod_dtype::DeviceSpec;
17use svod_ir::UOp;
18
19use crate::LlvmKernel;
20use crate::clang::ClangKernel;
21use crate::dispatch::KernelCif;
22
/// CPU backend selection.
///
/// `Clang` is the `Default` variant. `Mlir` only exists when the crate is
/// built with the `mlir` feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CpuBackend {
    /// Clang C codegen backend (default).
    /// Generates C source, compiles with clang, loads via dlopen.
    #[default]
    Clang,
    /// LLVM JIT backend.
    /// Maximum optimization, slower compilation.
    Llvm,
    /// MLIR backend.
    /// Generates MLIR, lowers to LLVM IR, then JIT compiles.
    #[cfg(feature = "mlir")]
    Mlir,
}

impl CpuBackend {
    /// Select backend from environment variable `SVOD_CPU_BACKEND`.
    ///
    /// The value is matched ASCII case-insensitively after trimming
    /// surrounding whitespace, so `llvm`, `LLVM` and `Llvm` all select
    /// [`CpuBackend::Llvm`]. An unset, non-UTF-8 or unrecognized value yields
    /// [`CpuBackend::default`].
    pub fn from_env() -> Self {
        std::env::var("SVOD_CPU_BACKEND")
            .ok()
            .and_then(|value| Self::from_name(&value))
            .unwrap_or_default()
    }

    /// Map a backend name to its variant, or `None` when the name is unknown.
    ///
    /// `"mlir"` is only recognized when the `mlir` feature is enabled.
    fn from_name(name: &str) -> Option<Self> {
        match name.trim().to_ascii_lowercase().as_str() {
            "clang" => Some(Self::Clang),
            "llvm" => Some(Self::Llvm),
            #[cfg(feature = "mlir")]
            "mlir" => Some(Self::Mlir),
            _ => None,
        }
    }
}
51
52// =============================================================================
53// Shared parallel execution
54// =============================================================================
55
56/// Execute a kernel function pointer in parallel across multiple threads.
57///
58/// # Safety
59///
60/// Buffer safety is guaranteed by the shift_to() transformation:
61/// - Each core_id maps to disjoint output indices
62/// - Index formula: `output[core_id * chunk_size + local_idx]`
63///
64/// Same buffer pointers can be safely passed to all threads because:
65/// 1. Input buffers: Read-only access (no data race)
66/// 2. Output buffers: Disjoint write regions per thread
67unsafe fn execute_parallel(
68    cif: &KernelCif,
69    fn_ptr: *const (),
70    buffers: &[*mut u8],
71    vals: &[i64],
72    var_names: &[String],
73    core_count: usize,
74) -> Result<()> {
75    use rayon::prelude::*;
76
77    let core_id_idx = var_names.iter().position(|n| n == "core_id").ok_or_else(|| svod_device::Error::Runtime {
78        message: "parallel CPU launch requires core_id runtime variable".to_string(),
79    })?;
80    let fn_ptr_usize = fn_ptr as usize;
81
82    // Convert raw pointers to usize for Send-safe cross-thread sharing.
83    // Safety: buffer pointers are read-only and point to disjoint write
84    // regions per thread (guaranteed by shift_to transformation).
85    let buf_ptr = buffers.as_ptr() as usize;
86    let buf_len = buffers.len();
87
88    // Nested parallelism policy: if we're already inside rayon work, avoid
89    // spawning another parallel loop for core_id kernels.
90    if rayon::current_thread_index().is_some() {
91        for core_id in 0..core_count {
92            let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
93            unsafe {
94                cif.dispatch(fn_ptr_usize as *const (), bufs, vals, Some((core_id_idx, core_id)));
95            }
96        }
97        return Ok(());
98    }
99
100    (0..core_count).into_par_iter().for_each(|core_id| {
101        let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
102        unsafe {
103            cif.dispatch(fn_ptr_usize as *const (), bufs, vals, Some((core_id_idx, core_id)));
104        }
105    });
106
107    Ok(())
108}
109
110// =============================================================================
111// Shared kernel execution
112// =============================================================================
113
114/// Execute a kernel: parallel if global_size > 1, otherwise single-threaded.
115unsafe fn execute_kernel(
116    cif: &KernelCif,
117    fn_ptr: *const (),
118    buffers: &[*mut u8],
119    vals: &[i64],
120    var_names: &[String],
121    global_size: Option<[usize; 3]>,
122) -> Result<()> {
123    let core_count = global_size.map(|[tc, _, _]| tc).filter(|&tc| tc > 1);
124    if let Some(count) = core_count {
125        unsafe { execute_parallel(cif, fn_ptr, buffers, vals, var_names, count) }
126    } else {
127        unsafe { cif.dispatch(fn_ptr, buffers, vals, None) };
128        Ok(())
129    }
130}
131
132// =============================================================================
133// Clang Backend
134// =============================================================================
135
136/// Clang program wrapper implementing the Program trait.
137struct ClangProgram {
138    kernel: ClangKernel,
139}
140
141impl Program for ClangProgram {
142    unsafe fn execute(
143        &self,
144        buffers: &[*mut u8],
145        vals: &[i64],
146        global_size: Option<[usize; 3]>,
147        _local_size: Option<[usize; 3]>,
148    ) -> Result<()> {
149        unsafe {
150            execute_kernel(self.kernel.cif(), self.kernel.fn_ptr(), buffers, vals, self.kernel.var_names(), global_size)
151        }
152    }
153
154    fn name(&self) -> &str {
155        self.kernel.name()
156    }
157}
158
159/// Clang renderer wrapper implementing the Renderer trait.
160struct ClangRendererWrapper {
161    device: DeviceSpec,
162}
163
164impl Renderer for ClangRendererWrapper {
165    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
166        let rendered = svod_codegen::c::render(ast, name.or(Some("kernel")))
167            .map_err(|e| svod_device::Error::Runtime { message: format!("C rendering failed: {}", e) })?;
168
169        let mut spec = ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
170
171        spec.set_var_names(rendered.var_names.clone());
172        spec.apply_derived_metadata_from_ast();
173        if spec.buf_count == 0 {
174            spec.buf_count = rendered.buffer_args.len();
175        }
176
177        Ok(spec)
178    }
179
180    fn device(&self) -> &DeviceSpec {
181        &self.device
182    }
183}
184
185/// Clang compiler - passes C source through for clang compilation.
186struct ClangCompiler;
187
188impl Compiler for ClangCompiler {
189    fn compile(&self, spec: &ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
190        let mut compiled = svod_device::device::CompiledSpec::from_source(
191            spec.name.clone(),
192            spec.src.clone(),
193            spec.ast.clone(),
194            spec.buf_count,
195        );
196        compiled.var_names = spec.var_names.clone();
197        compiled.global_size = spec.global_size.clone();
198        compiled.local_size = spec.local_size.clone();
199        Ok(compiled)
200    }
201
202    fn cache_key(&self) -> &'static str {
203        "clang"
204    }
205}
206
207/// Runtime factory for creating Clang programs.
208fn create_clang_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
209    let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
210        message: "Clang backend requires source code in CompiledSpec".to_string(),
211    })?;
212
213    let kernel = ClangKernel::compile(src, &spec.name, spec.var_names.clone(), spec.buf_count)
214        .map_err(|e| svod_device::Error::Runtime { message: format!("Clang compilation failed: {}", e) })?;
215
216    Ok(Box::new(ClangProgram { kernel }))
217}
218
219// =============================================================================
220// LLVM Backend
221// =============================================================================
222
223/// LLVM program wrapper implementing the Program trait.
224struct LlvmProgram {
225    kernel: LlvmKernel,
226}
227
228impl Program for LlvmProgram {
229    unsafe fn execute(
230        &self,
231        buffers: &[*mut u8],
232        vals: &[i64],
233        global_size: Option<[usize; 3]>,
234        _local_size: Option<[usize; 3]>,
235    ) -> Result<()> {
236        unsafe {
237            execute_kernel(self.kernel.cif(), self.kernel.fn_ptr(), buffers, vals, self.kernel.var_names(), global_size)
238        }
239    }
240
241    fn name(&self) -> &str {
242        self.kernel.name()
243    }
244}
245
246/// LLVM compiler implementing the Compiler trait.
247struct LlvmCompiler;
248
249impl Compiler for LlvmCompiler {
250    fn compile(&self, spec: &svod_device::device::ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
251        let mut compiled = svod_device::device::CompiledSpec::from_source(
252            spec.name.clone(),
253            spec.src.clone(),
254            spec.ast.clone(),
255            spec.buf_count,
256        );
257        compiled.var_names = spec.var_names.clone();
258        compiled.global_size = spec.global_size.clone();
259        compiled.local_size = spec.local_size.clone();
260        Ok(compiled)
261    }
262
263    fn cache_key(&self) -> &'static str {
264        "llvm-jit"
265    }
266}
267
268/// LLVM renderer wrapper implementing the Renderer trait.
269struct LlvmRendererWrapper {
270    device: DeviceSpec,
271}
272
273impl Renderer for LlvmRendererWrapper {
274    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
275        let rendered = svod_codegen::llvm::text::render(ast, name.or(Some("kernel")))
276            .map_err(|e| svod_device::Error::Runtime { message: format!("LLVM rendering failed: {}", e) })?;
277
278        let mut spec = ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
279
280        spec.set_var_names(rendered.var_names.clone());
281        spec.apply_derived_metadata_from_ast();
282        if spec.buf_count == 0 {
283            spec.buf_count = rendered.buffer_args.len();
284        }
285
286        Ok(spec)
287    }
288
289    fn device(&self) -> &DeviceSpec {
290        &self.device
291    }
292}
293
294/// Runtime factory for creating LLVM programs.
295fn create_llvm_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
296    let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
297        message: "LLVM JIT requires source code in CompiledSpec".to_string(),
298    })?;
299
300    let kernel = crate::LlvmKernel::compile_ir(src, &spec.name, &spec.name, spec.var_names.clone(), spec.buf_count)
301        .map_err(|e| svod_device::Error::Runtime { message: format!("LLVM JIT compilation failed: {}", e) })?;
302
303    Ok(Box::new(LlvmProgram { kernel }))
304}
305
306// =============================================================================
307// MLIR Backend
308// =============================================================================
309
310#[cfg(feature = "mlir")]
311mod mlir_backend {
312    use std::ffi::c_void;
313
314    use super::*;
315
316    type MlirKernelFn = unsafe extern "C" fn(*const *mut u8, *const i64);
317
318    unsafe fn dispatch_mlir_fn(fn_ptr: *const c_void, buffers: &[*mut u8], vals: &[i64]) {
319        let kernel: MlirKernelFn = unsafe { std::mem::transmute(fn_ptr) };
320        let buffer_usizes: Vec<usize> = buffers.iter().map(|&ptr| ptr as usize).collect();
321        let bufs_ptr = buffer_usizes.as_ptr() as *const *mut u8;
322        unsafe {
323            kernel(bufs_ptr, vals.as_ptr());
324        }
325    }
326
327    unsafe fn execute_mlir_parallel(
328        fn_ptr: *const c_void,
329        buffers: &[*mut u8],
330        vals: &[i64],
331        var_names: &[String],
332        core_count: usize,
333    ) -> Result<()> {
334        use rayon::prelude::*;
335
336        let core_id_idx = var_names.iter().position(|n| n == "core_id").ok_or_else(|| svod_device::Error::Runtime {
337            message: "parallel MLIR CPU launch requires core_id runtime variable".to_string(),
338        })?;
339        let fn_ptr_usize = fn_ptr as usize;
340
341        // Convert raw pointers to usize for Send-safe cross-thread sharing.
342        let buf_ptr = buffers.as_ptr() as usize;
343        let buf_len = buffers.len();
344        let vals = vals.to_vec();
345
346        // Avoid nested parallelism when already executing inside rayon worker.
347        if rayon::current_thread_index().is_some() {
348            for core_id in 0..core_count {
349                let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
350                let mut thread_vals = vals.clone();
351                thread_vals[core_id_idx] = core_id as i64;
352                unsafe { dispatch_mlir_fn(fn_ptr_usize as *const c_void, bufs, &thread_vals) };
353            }
354            return Ok(());
355        }
356
357        (0..core_count).into_par_iter().for_each(|core_id| {
358            let bufs = unsafe { std::slice::from_raw_parts(buf_ptr as *const *mut u8, buf_len) };
359            let mut thread_vals = vals.clone();
360            thread_vals[core_id_idx] = core_id as i64;
361            unsafe { dispatch_mlir_fn(fn_ptr_usize as *const c_void, bufs, &thread_vals) };
362        });
363
364        Ok(())
365    }
366
367    /// MLIR program wrapper using ExecutionEngine.
368    pub struct MlirProgram {
369        pub kernel: crate::mlir::MlirKernel,
370    }
371
372    impl Program for MlirProgram {
373        unsafe fn execute(
374            &self,
375            buffers: &[*mut u8],
376            vals: &[i64],
377            global_size: Option<[usize; 3]>,
378            _local_size: Option<[usize; 3]>,
379        ) -> Result<()> {
380            let core_count = global_size.map(|[tc, _, _]| tc).filter(|&tc| tc > 1);
381            let fn_ptr = self.kernel.fn_ptr();
382
383            if let Some(count) = core_count {
384                unsafe { execute_mlir_parallel(fn_ptr, buffers, vals, self.kernel.var_names(), count) }
385            } else {
386                unsafe { dispatch_mlir_fn(fn_ptr, buffers, vals) };
387                Ok(())
388            }
389        }
390
391        fn name(&self) -> &str {
392            self.kernel.name()
393        }
394    }
395
396    /// MLIR renderer wrapper implementing the Renderer trait.
397    pub struct MlirRendererWrapper {
398        pub device: DeviceSpec,
399    }
400
401    impl Renderer for MlirRendererWrapper {
402        fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec> {
403            let rendered = svod_codegen::mlir::render(ast, name.or(Some("kernel")))
404                .map_err(|e| svod_device::Error::Runtime { message: format!("MLIR rendering failed: {}", e) })?;
405
406            let mut spec =
407                ProgramSpec::new(rendered.name.clone(), rendered.code.clone(), self.device.clone(), ast.clone());
408
409            spec.set_var_names(rendered.var_names.clone());
410            spec.apply_derived_metadata_from_ast();
411            if spec.buf_count == 0 {
412                spec.buf_count = rendered.buffer_args.len();
413            }
414
415            Ok(spec)
416        }
417
418        fn device(&self) -> &DeviceSpec {
419            &self.device
420        }
421
422        fn decompositor(&self) -> Option<svod_ir::pattern::TypedPatternMatcher<()>> {
423            use svod_ir::decompositions::ptrcat_decomposition_patterns;
424            Some(ptrcat_decomposition_patterns())
425        }
426    }
427
428    /// MLIR compiler implementing the Compiler trait.
429    pub struct MlirCompiler;
430
431    impl Compiler for MlirCompiler {
432        fn compile(&self, spec: &svod_device::device::ProgramSpec) -> Result<svod_device::device::CompiledSpec> {
433            let mut compiled = svod_device::device::CompiledSpec::from_source(
434                spec.name.clone(),
435                spec.src.clone(),
436                spec.ast.clone(),
437                spec.buf_count,
438            );
439            compiled.var_names = spec.var_names.clone();
440            compiled.global_size = spec.global_size.clone();
441            compiled.local_size = spec.local_size.clone();
442            Ok(compiled)
443        }
444
445        fn cache_key(&self) -> &'static str {
446            "mlir-exec-engine"
447        }
448    }
449
450    /// Runtime factory for creating MLIR programs.
451    pub fn create_mlir_program(spec: &svod_device::device::CompiledSpec) -> Result<Box<dyn Program>> {
452        let src = spec.src.as_ref().ok_or_else(|| svod_device::Error::Runtime {
453            message: "MLIR backend requires source code (MLIR text) in CompiledSpec".to_string(),
454        })?;
455
456        let kernel = crate::mlir::MlirKernel::compile(src, &spec.name, spec.var_names.clone()).map_err(|e| {
457            svod_device::Error::Runtime { message: format!("MLIR ExecutionEngine compilation failed: {}", e) }
458        })?;
459
460        Ok(Box::new(MlirProgram { kernel }))
461    }
462}
463
464#[cfg(feature = "mlir")]
465use mlir_backend::{MlirCompiler, MlirRendererWrapper, create_mlir_program};
466
467// =============================================================================
468// Public API
469// =============================================================================
470
471/// Create a CPU device with the default backend.
472///
473/// The default backend is selected by:
474/// 1. `SVOD_CPU_BACKEND` environment variable ("clang" or "llvm")
475/// 2. If not set, defaults to Clang
476pub fn create_cpu_device(registry: &DeviceRegistry) -> Result<Device> {
477    create_cpu_device_with_backend(registry, CpuBackend::from_env())
478}
479
480/// Create a CPU device with a specific backend.
481pub fn create_cpu_device_with_backend(registry: &DeviceRegistry, backend: CpuBackend) -> Result<Device> {
482    let device_spec = DeviceSpec::Cpu;
483    let allocator = registry.get(&device_spec)?;
484
485    match backend {
486        CpuBackend::Clang => {
487            let renderer = Arc::new(ClangRendererWrapper { device: device_spec.clone() });
488            let compiler = Arc::new(ClangCompiler);
489            let runtime: RuntimeFactory = Arc::new(create_clang_program);
490            Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
491        }
492        CpuBackend::Llvm => {
493            let renderer = Arc::new(LlvmRendererWrapper { device: device_spec.clone() });
494            let compiler = Arc::new(LlvmCompiler);
495            let runtime: RuntimeFactory = Arc::new(create_llvm_program);
496            Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
497        }
498        #[cfg(feature = "mlir")]
499        CpuBackend::Mlir => {
500            let renderer = Arc::new(MlirRendererWrapper { device: device_spec.clone() });
501            let compiler = Arc::new(MlirCompiler);
502            let runtime: RuntimeFactory = Arc::new(create_mlir_program);
503            Ok(Device::new(device_spec, allocator, renderer, compiler, runtime))
504        }
505    }
506}