1use anyhow::{anyhow, Result};
2use num_complex::Complex;
3use std::collections::HashMap;
4use std::ffi::c_void;
5use std::fmt;
6use std::mem::size_of;
7use std::slice::{from_raw_parts, from_raw_parts_mut};
8use wide::{f64x2, f64x4};
9
/// Boxed, thread-safe host callback: it receives the argument slice and
/// produces one value of the same element type `T`.
type ExternalFunction<T> = Box<dyn Fn(&[T]) -> T + Sync + Send>;
11
12use crate::code::{BinaryFunc, BinaryFuncCplx, Func, UnaryFunc, UnaryFuncCplx, VirtualTable};
13use crate::config::SLICE_CAP;
14use crate::types::{ElemType, Element};
15
16#[derive(Debug, Clone)]
17pub struct RawBox {
18 func_ptr: *mut c_void,
19 elem_type: ElemType,
20}
21
22unsafe impl Send for RawBox {}
23unsafe impl Sync for RawBox {}
24
25#[cfg(target_arch = "aarch64")]
26type NativeSimd = f64x2;
27
28#[cfg(target_arch = "x86_64")]
29type NativeSimd = f64x4;
30
31pub unsafe extern "C" fn trampoline_homogenous<T>(
32 env: *const c_void,
33 slice_ptr: *const T,
34 slice_len: usize,
35 res: *mut T,
36) -> bool
37where
38 T: Sized + Copy + Default,
39{
40 let closure = &*(env as *const ExternalFunction<T>);
41 let slice = from_raw_parts(slice_ptr, slice_len);
42 *res = closure(slice);
43 false
44}
45
46pub unsafe extern "C" fn trampoline_call_scalar<T, F>(
47 env: *const c_void,
48 slice_ptr: *const T,
49 slice_len: usize,
50 res: *mut T,
51) -> bool
52where
53 T: Sized + Copy + Default,
54 F: Sized + Copy + Default,
55{
56 assert!(slice_len <= SLICE_CAP && size_of::<T>() > size_of::<F>());
57
58 let closure = &*(env as *const ExternalFunction<F>);
59 let mut buf = [F::default(); SLICE_CAP];
60 let step = size_of::<T>() / size_of::<F>();
61 let slice = from_raw_parts(slice_ptr as *mut F, step * slice_len);
62 let res = from_raw_parts_mut(res as *mut F, step);
63
64 for i in 0..step {
65 for j in 0..slice_len {
66 buf[j] = slice[j * step + i];
67 }
68 res[i] = closure(&buf[..slice_len]);
69 }
70
71 true
75}
76
77unsafe fn real<T: Element>(x: T) -> f64 {
78 let p = &x as *const _ as *const f64;
79 *p
80}
81
82unsafe fn imag<T: Element>(x: T) -> f64 {
83 match T::get_type(x) {
84 ElemType::RealF64(_) | ElemType::RealF64x2(_) | ElemType::RealF64x4(_) => 0.0,
85 ElemType::ComplexF64(x) => x.re,
86 ElemType::ComplexF64x2(x) => real(x.re),
87 ElemType::ComplexF64x4(x) => real(x.re),
88 }
89}
90
91pub unsafe extern "C" fn trampoline_call_simd<T, F>(
92 env: *const c_void,
93 slice_ptr: *const T,
94 slice_len: usize,
95 res: *mut T,
96) -> bool
97where
98 T: Sized + Copy + Element,
99 F: Sized + Copy + Element,
100{
101 assert!(slice_len <= SLICE_CAP && size_of::<T>() < size_of::<F>());
102
103 let closure = &*(env as *const ExternalFunction<F>);
104 let buf = [F::default(); SLICE_CAP];
105 let step = size_of::<F>() / size_of::<T>();
106 let slice = from_raw_parts(slice_ptr, slice_len);
107 let p = from_raw_parts_mut(buf.as_ptr() as *mut T, step * slice_len);
108
109 for j in 0..slice_len {
110 for i in 0..step {
111 p[j * step + i] = slice[j];
112 }
113 }
114
115 let val = closure(&buf[..slice_len]);
116 let mut res: *mut f64 = res as _;
117 *res = real(val);
118 res = res.add(1);
119 *res = imag(val);
120 false
121}
122
123#[derive(Clone, Default)]
124pub struct Defuns {
125 pub funcs: HashMap<String, Func>,
126 pub boxes: Vec<RawBox>,
127}
128
129impl fmt::Debug for Defuns {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 writeln!(f, "{:?}", &self.funcs)?;
132 Ok(())
133 }
134}
135
136impl Defuns {
137 pub fn new() -> Defuns {
138 Defuns {
139 funcs: HashMap::new(),
140 boxes: Vec::new(),
141 }
142 }
143
144 pub fn add_func(&mut self, name: &str, p: *const usize, num_args: usize) {
145 match num_args {
146 1 => {
147 let f: UnaryFunc = unsafe { std::mem::transmute(p) };
148 self.funcs.insert(name.to_string(), Func::Unary(f));
149 }
150 2 => {
151 let f: BinaryFunc = unsafe { std::mem::transmute(p) };
152 self.funcs.insert(name.to_string(), Func::Binary(f));
153 }
154 _ => {
155 panic!("only unary and binary functions are supported")
156 }
157 }
158 }
159
160 pub fn add_unary(&mut self, name: &str, f: UnaryFunc) {
161 self.funcs.insert(name.to_string(), Func::Unary(f));
162 }
163
164 pub fn add_binary(&mut self, name: &str, f: BinaryFunc) {
165 self.funcs.insert(name.to_string(), Func::Binary(f));
166 }
167
168 pub fn add_unary_complex(&mut self, name: &str, f: UnaryFuncCplx) {
169 self.funcs
170 .insert(format!("cplx_{}", name), Func::UnaryCplx(f));
171 }
172
173 pub fn add_binary_complex(&mut self, name: &str, f: BinaryFuncCplx) {
174 self.funcs
175 .insert(format!("cplx_{}", name), Func::BinaryCplx(f));
176 }
177
178 pub fn add_sliced_func<T>(&mut self, name: &str, closure: ExternalFunction<T>) -> Result<()>
179 where
180 T: Copy + Sized + Element,
181 {
182 if VirtualTable::from_str(name).is_ok() {
183 return Err(anyhow!("cannot redefine function {}.", &name));
184 }
185
186 let ext = Box::new(closure);
187 let env = ext.as_ref() as *const _ as *const c_void;
188
189 let trampoline: *const c_void = match T::get_type(T::default()) {
190 ElemType::RealF64(_) | ElemType::ComplexF64(_) => {
191 trampoline_homogenous::<T> as *const c_void
192 }
193 _ => trampoline_call_simd::<f64, T> as *const c_void,
194 };
195
196 let trampoline_simd: *const c_void = match T::get_type(T::default()) {
197 ElemType::RealF64(_) => trampoline_call_scalar::<NativeSimd, T> as *const c_void,
198 ElemType::ComplexF64(_) => {
199 trampoline_call_scalar::<Complex<NativeSimd>, T> as *const c_void
200 }
201 _ => trampoline_homogenous::<T> as *const c_void,
202 };
203
204 let op = format!("${}", name);
205
206 self.funcs.insert(
207 op,
208 Func::Slice {
209 f_scalar: trampoline,
210 f_simd: trampoline_simd,
211 env,
212 },
213 );
214
215 let func_ptr = Box::into_raw(ext);
216
217 self.boxes.push(RawBox {
218 func_ptr: func_ptr as *mut _,
219 elem_type: T::get_type(T::default()),
220 });
221
222 Ok(())
223 }
224
225 pub fn len(&self) -> usize {
226 self.funcs.len()
227 }
228
229 pub fn is_empty(&self) -> bool {
230 self.len() == 0
231 }
232}
233
234impl Drop for RawBox {
235 fn drop(&mut self) {
236 unsafe {
237 match self.elem_type {
238 ElemType::RealF64(_) => {
239 let p: *mut ExternalFunction<f64> = self.func_ptr as *mut _;
240 let _: Box<ExternalFunction<f64>> = Box::from_raw(p);
241 }
242 ElemType::ComplexF64(_) => {
243 let p: *mut ExternalFunction<Complex<f64>> = self.func_ptr as *mut _;
244 let _: Box<ExternalFunction<Complex<f64>>> = Box::from_raw(p);
245 }
246 ElemType::RealF64x2(_) => {
247 let p: *mut ExternalFunction<f64x2> = self.func_ptr as *mut _;
248 let _: Box<ExternalFunction<f64x2>> = Box::from_raw(p);
249 }
250 ElemType::ComplexF64x2(_) => {
251 let p: *mut ExternalFunction<Complex<f64x2>> = self.func_ptr as *mut _;
252 let _: Box<ExternalFunction<Complex<f64x2>>> = Box::from_raw(p);
253 }
254 ElemType::RealF64x4(_) => {
255 let p: *mut ExternalFunction<f64x4> = self.func_ptr as *mut _;
256 let _: Box<ExternalFunction<f64x4>> = Box::from_raw(p);
257 }
258 ElemType::ComplexF64x4(_) => {
259 let p: *mut ExternalFunction<Complex<f64x4>> = self.func_ptr as *mut _;
260 let _: Box<ExternalFunction<Complex<f64x4>>> = Box::from_raw(p);
261 }
262 }
263 }
264 }
265}