Skip to main content

symjit/
defuns.rs

1use anyhow::{anyhow, Result};
2use num_complex::Complex;
3use std::collections::HashMap;
4use std::ffi::c_void;
5use std::fmt;
6use std::mem::size_of;
7use std::slice::{from_raw_parts, from_raw_parts_mut};
8use wide::{f64x2, f64x4};
9
/// Boxed user callback: receives a slice of `T` arguments and returns a single `T`.
/// The `Send + Sync` bounds allow the boxed closure to be shared between threads.
type ExternalFunction<T> = Box<dyn Fn(&[T]) -> T + Send + Sync>;
11
12use crate::code::{BinaryFunc, BinaryFuncCplx, Func, UnaryFunc, UnaryFuncCplx, VirtualTable};
13use crate::config::SLICE_CAP;
14use crate::types::{ElemType, Element};
15
/// Type-erased owner of a heap-allocated `ExternalFunction<T>`.
///
/// `func_ptr` is the pointer produced by `Box::into_raw` in
/// `Defuns::add_sliced_func`; `elem_type` records which `T` the closure was
/// registered with so that `Drop` can rebuild the correctly typed box and
/// free it.
///
/// NOTE(review): `Clone` is derived, so a clone copies the raw pointer, while
/// `Drop` frees it. Cloning a `RawBox` (e.g. through `Defuns::clone`) therefore
/// looks like it ends in a double free once both copies are dropped — confirm,
/// and consider shared ownership instead.
#[derive(Debug, Clone)]
pub struct RawBox {
    func_ptr: *mut c_void,
    elem_type: ElemType,
}

// SAFETY (assumed): the pointee is an `ExternalFunction<T>`, whose closure is
// bounded by `Send + Sync`, so moving or sharing the pointer across threads is
// presumably sound. TODO confirm.
unsafe impl Send for RawBox {}
unsafe impl Sync for RawBox {}
24
/// SIMD vector type used when selecting the vectorized trampolines on aarch64.
#[cfg(target_arch = "aarch64")]
type NativeSimd = f64x2;

/// SIMD vector type used when selecting the vectorized trampolines on x86_64.
#[cfg(target_arch = "x86_64")]
type NativeSimd = f64x4;
30
31pub unsafe extern "C" fn trampoline_homogenous<T>(
32    env: *const c_void,
33    slice_ptr: *const T,
34    slice_len: usize,
35    res: *mut T,
36) -> bool
37where
38    T: Sized + Copy + Default,
39{
40    let closure = &*(env as *const ExternalFunction<T>);
41    let slice = from_raw_parts(slice_ptr, slice_len);
42    *res = closure(slice);
43    false
44}
45
46pub unsafe extern "C" fn trampoline_call_scalar<T, F>(
47    env: *const c_void,
48    slice_ptr: *const T,
49    slice_len: usize,
50    res: *mut T,
51) -> bool
52where
53    T: Sized + Copy + Default,
54    F: Sized + Copy + Default,
55{
56    assert!(slice_len <= SLICE_CAP && size_of::<T>() > size_of::<F>());
57
58    let closure = &*(env as *const ExternalFunction<F>);
59    let mut buf = [F::default(); SLICE_CAP];
60    let step = size_of::<T>() / size_of::<F>();
61    let slice = from_raw_parts(slice_ptr as *mut F, step * slice_len);
62    let res = from_raw_parts_mut(res as *mut F, step);
63
64    for i in 0..step {
65        for j in 0..slice_len {
66            buf[j] = slice[j * step + i];
67        }
68        res[i] = closure(&buf[..slice_len]);
69    }
70
71    // a return value of true signals the SIMD kernel to shuffle the result.
72    // For example, if T = Complex<f64x2>, at this stage `res` is
73    // `x1 y1 x2 y2` but should be `x1 x2 y1 y2`.
74    true
75}
76
77unsafe fn real<T: Element>(x: T) -> f64 {
78    let p = &x as *const _ as *const f64;
79    *p
80}
81
82unsafe fn imag<T: Element>(x: T) -> f64 {
83    match T::get_type(x) {
84        ElemType::RealF64(_) | ElemType::RealF64x2(_) | ElemType::RealF64x4(_) => 0.0,
85        ElemType::ComplexF64(x) => x.re,
86        ElemType::ComplexF64x2(x) => real(x.re),
87        ElemType::ComplexF64x4(x) => real(x.re),
88    }
89}
90
91pub unsafe extern "C" fn trampoline_call_simd<T, F>(
92    env: *const c_void,
93    slice_ptr: *const T,
94    slice_len: usize,
95    res: *mut T,
96) -> bool
97where
98    T: Sized + Copy + Element,
99    F: Sized + Copy + Element,
100{
101    assert!(slice_len <= SLICE_CAP && size_of::<T>() < size_of::<F>());
102
103    let closure = &*(env as *const ExternalFunction<F>);
104    let buf = [F::default(); SLICE_CAP];
105    let step = size_of::<F>() / size_of::<T>();
106    let slice = from_raw_parts(slice_ptr, slice_len);
107    let p = from_raw_parts_mut(buf.as_ptr() as *mut T, step * slice_len);
108
109    for j in 0..slice_len {
110        for i in 0..step {
111            p[j * step + i] = slice[j];
112        }
113    }
114
115    let val = closure(&buf[..slice_len]);
116    let mut res: *mut f64 = res as _;
117    *res = real(val);
118    res = res.add(1);
119    *res = imag(val);
120    false
121}
122
/// Registry of externally supplied functions.
///
/// `funcs` maps a registered name to its `Func` entry. `boxes` holds the
/// `RawBox` owners of the closures registered through `add_sliced_func`.
#[derive(Clone, Default)]
pub struct Defuns {
    pub funcs: HashMap<String, Func>,
    pub boxes: Vec<RawBox>,
}
128
129impl fmt::Debug for Defuns {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        writeln!(f, "{:?}", &self.funcs)?;
132        Ok(())
133    }
134}
135
136impl Defuns {
137    pub fn new() -> Defuns {
138        Defuns {
139            funcs: HashMap::new(),
140            boxes: Vec::new(),
141        }
142    }
143
144    pub fn add_func(&mut self, name: &str, p: *const usize, num_args: usize) {
145        match num_args {
146            1 => {
147                let f: UnaryFunc = unsafe { std::mem::transmute(p) };
148                self.funcs.insert(name.to_string(), Func::Unary(f));
149            }
150            2 => {
151                let f: BinaryFunc = unsafe { std::mem::transmute(p) };
152                self.funcs.insert(name.to_string(), Func::Binary(f));
153            }
154            _ => {
155                panic!("only unary and binary functions are supported")
156            }
157        }
158    }
159
160    pub fn add_unary(&mut self, name: &str, f: UnaryFunc) {
161        self.funcs.insert(name.to_string(), Func::Unary(f));
162    }
163
164    pub fn add_binary(&mut self, name: &str, f: BinaryFunc) {
165        self.funcs.insert(name.to_string(), Func::Binary(f));
166    }
167
168    pub fn add_unary_complex(&mut self, name: &str, f: UnaryFuncCplx) {
169        self.funcs
170            .insert(format!("cplx_{}", name), Func::UnaryCplx(f));
171    }
172
173    pub fn add_binary_complex(&mut self, name: &str, f: BinaryFuncCplx) {
174        self.funcs
175            .insert(format!("cplx_{}", name), Func::BinaryCplx(f));
176    }
177
178    pub fn add_sliced_func<T>(&mut self, name: &str, closure: ExternalFunction<T>) -> Result<()>
179    where
180        T: Copy + Sized + Element,
181    {
182        if VirtualTable::from_str(name).is_ok() {
183            return Err(anyhow!("cannot redefine function {}.", &name));
184        }
185
186        let ext = Box::new(closure);
187        let env = ext.as_ref() as *const _ as *const c_void;
188
189        let trampoline: *const c_void = match T::get_type(T::default()) {
190            ElemType::RealF64(_) | ElemType::ComplexF64(_) => {
191                trampoline_homogenous::<T> as *const c_void
192            }
193            _ => trampoline_call_simd::<f64, T> as *const c_void,
194        };
195
196        let trampoline_simd: *const c_void = match T::get_type(T::default()) {
197            ElemType::RealF64(_) => trampoline_call_scalar::<NativeSimd, T> as *const c_void,
198            ElemType::ComplexF64(_) => {
199                trampoline_call_scalar::<Complex<NativeSimd>, T> as *const c_void
200            }
201            _ => trampoline_homogenous::<T> as *const c_void,
202        };
203
204        let op = format!("${}", name);
205
206        self.funcs.insert(
207            op,
208            Func::Slice {
209                f_scalar: trampoline,
210                f_simd: trampoline_simd,
211                env,
212            },
213        );
214
215        let func_ptr = Box::into_raw(ext);
216
217        self.boxes.push(RawBox {
218            func_ptr: func_ptr as *mut _,
219            elem_type: T::get_type(T::default()),
220        });
221
222        Ok(())
223    }
224
225    pub fn len(&self) -> usize {
226        self.funcs.len()
227    }
228
229    pub fn is_empty(&self) -> bool {
230        self.len() == 0
231    }
232}
233
234impl Drop for RawBox {
235    fn drop(&mut self) {
236        unsafe {
237            match self.elem_type {
238                ElemType::RealF64(_) => {
239                    let p: *mut ExternalFunction<f64> = self.func_ptr as *mut _;
240                    let _: Box<ExternalFunction<f64>> = Box::from_raw(p);
241                }
242                ElemType::ComplexF64(_) => {
243                    let p: *mut ExternalFunction<Complex<f64>> = self.func_ptr as *mut _;
244                    let _: Box<ExternalFunction<Complex<f64>>> = Box::from_raw(p);
245                }
246                ElemType::RealF64x2(_) => {
247                    let p: *mut ExternalFunction<f64x2> = self.func_ptr as *mut _;
248                    let _: Box<ExternalFunction<f64x2>> = Box::from_raw(p);
249                }
250                ElemType::ComplexF64x2(_) => {
251                    let p: *mut ExternalFunction<Complex<f64x2>> = self.func_ptr as *mut _;
252                    let _: Box<ExternalFunction<Complex<f64x2>>> = Box::from_raw(p);
253                }
254                ElemType::RealF64x4(_) => {
255                    let p: *mut ExternalFunction<f64x4> = self.func_ptr as *mut _;
256                    let _: Box<ExternalFunction<f64x4>> = Box::from_raw(p);
257                }
258                ElemType::ComplexF64x4(_) => {
259                    let p: *mut ExternalFunction<Complex<f64x4>> = self.func_ptr as *mut _;
260                    let _: Box<ExternalFunction<Complex<f64x4>>> = Box::from_raw(p);
261                }
262            }
263        }
264    }
265}