use std::alloc::{Layout, alloc_zeroed};

const SIMD_THRESHOLD: usize = 16;

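// Element-wise binary operations over f64 buffers, exported with a C ABI
// for the JIT. Callers must pass pointers to at least `len` readable f64
// values. Each function returns a newly allocated buffer that must be
// released with `jit_simd_free`; null inputs or a zero length yield null.

/// Element-wise addition: out[i] = a[i] + b[i].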
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_add(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a + b)
}

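/// Element-wise subtraction: out[i] = a[i] - b[i].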
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_sub(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a - b)
}

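/// Element-wise multiplication: out[i] = a[i] * b[i].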
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_mul(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a * b)
}

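/// Element-wise division: out[i] = a[i] / b[i].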
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_div(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a / b)
}

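/// Element-wise maximum: out[i] = a[i].max(b[i]).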
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_max(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a.max(b))
}

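/// Element-wise minimum: out[i] = a[i].min(b[i]).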
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_min(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a.min(b))
}

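// Scalar broadcast variants: the second operand is a single f64 applied to
// every element of `a`.

/// Broadcast addition: out[i] = a[i] + scalar.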
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_add_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a + s)
}

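/// Broadcast subtraction: out[i] = a[i] - scalar.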
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_sub_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a - s)
}

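/// Broadcast multiplication: out[i] = a[i] * scalar.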
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_mul_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a * s)
}

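/// Broadcast division: out[i] = a[i] / scalar.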
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_div_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a / s)
}

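// Element-wise comparisons. Each returns an f64 mask buffer: 1.0 where the
// predicate holds, 0.0 elsewhere.

/// Greater-than: out[i] = 1.0 if a[i] > b[i], else 0.0.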
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_gt(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a > b)
}

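/// Less-than: out[i] = 1.0 if a[i] < b[i], else 0.0.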
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_lt(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a < b)
}

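/// Greater-than-or-equal: out[i] = 1.0 if a[i] >= b[i], else 0.0.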
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_gte(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a >= b)
}

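/// Less-than-or-equal: out[i] = 1.0 if a[i] <= b[i], else 0.0.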
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_lte(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a <= b)
}

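/// Approximate equality: out[i] = 1.0 if |a[i] - b[i]| < f64::EPSILON.
/// Note this is an absolute tolerance, so for magnitudes above 2.0 it
/// behaves like exact equality.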
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_eq(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| {
        (a - b).abs() < f64::EPSILON
    })
}

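/// Approximate inequality: the complement of `jit_simd_eq`.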
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_neq(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| {
        (a - b).abs() >= f64::EPSILON
    })
}

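/// Allocates a zero-initialized, 32-byte-aligned buffer of `len` f64 values.
/// The 32-byte alignment keeps the buffer friendly to 256-bit vector loads.
/// Returns null for a zero length or on allocation failure.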
#[inline]
fn alloc_f64_buffer(len: usize) -> *mut f64 {
    if len == 0 {
        return std::ptr::null_mut();
    }
    let layout =
        Layout::from_size_align(len * std::mem::size_of::<f64>(), 32).expect("Invalid layout");
    // `alloc_zeroed` rather than `alloc`: the helpers below build `&mut [f64]`
    // slices over this buffer before every element is written, which requires
    // the memory to already hold valid f64 values. Zeroed memory satisfies
    // that; uninitialized memory would not.
    unsafe { alloc_zeroed(layout) as *mut f64 }
}

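/// Applies `op` element-wise over two buffers into a freshly allocated
/// result buffer. For `len >= SIMD_THRESHOLD` the loop is unrolled four
/// wide to encourage auto-vectorization; shorter inputs use a plain loop.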
#[inline]
fn simd_binary_op<F>(a_ptr: *const f64, b_ptr: *const f64, len: usize, op: F) -> *mut f64
where
    F: Fn(f64, f64) -> f64,
{
    if a_ptr.is_null() || b_ptr.is_null() || len == 0 {
        return std::ptr::null_mut();
    }

    let result = alloc_f64_buffer(len);
    if result.is_null() {
        return std::ptr::null_mut();
    }

    unsafe {
        let a = std::slice::from_raw_parts(a_ptr, len);
        let b = std::slice::from_raw_parts(b_ptr, len);
        let out = std::slice::from_raw_parts_mut(result, len);

        if len >= SIMD_THRESHOLD {
            // Four elements per iteration: the fixed-stride body is intended
            // to be easy for the optimizer to turn into packed f64 ops.
            let chunks = len / 4;
            for i in 0..chunks {
                let idx = i * 4;
                out[idx] = op(a[idx], b[idx]);
                out[idx + 1] = op(a[idx + 1], b[idx + 1]);
                out[idx + 2] = op(a[idx + 2], b[idx + 2]);
                out[idx + 3] = op(a[idx + 3], b[idx + 3]);
            }
            // Scalar tail for the remaining 0-3 elements.
            for i in (chunks * 4)..len {
                out[i] = op(a[i], b[i]);
            }
        } else {
            for i in 0..len {
                out[i] = op(a[i], b[i]);
            }
        }
    }

    result
}

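/// Applies `op(a[i], scalar)` element-wise, with the same unrolling
/// strategy as `simd_binary_op`.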
#[inline]
fn simd_scalar_op<F>(a_ptr: *const f64, scalar: f64, len: usize, op: F) -> *mut f64
where
    F: Fn(f64, f64) -> f64,
{
    if a_ptr.is_null() || len == 0 {
        return std::ptr::null_mut();
    }

    let result = alloc_f64_buffer(len);
    if result.is_null() {
        return std::ptr::null_mut();
    }

    unsafe {
        let a = std::slice::from_raw_parts(a_ptr, len);
        let out = std::slice::from_raw_parts_mut(result, len);

        if len >= SIMD_THRESHOLD {
            let chunks = len / 4;
            for i in 0..chunks {
                let idx = i * 4;
                out[idx] = op(a[idx], scalar);
                out[idx + 1] = op(a[idx + 1], scalar);
                out[idx + 2] = op(a[idx + 2], scalar);
                out[idx + 3] = op(a[idx + 3], scalar);
            }
            for i in (chunks * 4)..len {
                out[i] = op(a[i], scalar);
            }
        } else {
            for i in 0..len {
                out[i] = op(a[i], scalar);
            }
        }
    }

    result
}

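/// Applies the predicate `op` element-wise and materializes the result as
/// an f64 mask (1.0 for true, 0.0 for false), with the same unrolling
/// strategy as `simd_binary_op`.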
#[inline]
fn simd_cmp_op<F>(a_ptr: *const f64, b_ptr: *const f64, len: usize, op: F) -> *mut f64
where
    F: Fn(f64, f64) -> bool,
{
    if a_ptr.is_null() || b_ptr.is_null() || len == 0 {
        return std::ptr::null_mut();
    }

    let result = alloc_f64_buffer(len);
    if result.is_null() {
        return std::ptr::null_mut();
    }

    unsafe {
        let a = std::slice::from_raw_parts(a_ptr, len);
        let b = std::slice::from_raw_parts(b_ptr, len);
        let out = std::slice::from_raw_parts_mut(result, len);

        if len >= SIMD_THRESHOLD {
            let chunks = len / 4;
            for i in 0..chunks {
                let idx = i * 4;
                out[idx] = if op(a[idx], b[idx]) { 1.0 } else { 0.0 };
                out[idx + 1] = if op(a[idx + 1], b[idx + 1]) { 1.0 } else { 0.0 };
                out[idx + 2] = if op(a[idx + 2], b[idx + 2]) { 1.0 } else { 0.0 };
                out[idx + 3] = if op(a[idx + 3], b[idx + 3]) { 1.0 } else { 0.0 };
            }
            for i in (chunks * 4)..len {
                out[i] = if op(a[i], b[i]) { 1.0 } else { 0.0 };
            }
        } else {
            for i in 0..len {
                out[i] = if op(a[i], b[i]) { 1.0 } else { 0.0 };
            }
        }
    }

    result
}

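/// Releases a buffer returned by the ops above. `len` must be the same
/// length that was passed when the buffer was created, so that the layout
/// computed here matches the one used at allocation time.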
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_free(ptr: *mut f64, len: u64) {
    if ptr.is_null() || len == 0 {
        return;
    }
    let layout = Layout::from_size_align(len as usize * std::mem::size_of::<f64>(), 32)
        .expect("Invalid layout");
    unsafe {
        std::alloc::dealloc(ptr as *mut u8, layout);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_add() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![10.0, 20.0, 30.0, 40.0];
        let result = jit_simd_add(a.as_ptr(), b.as_ptr(), 4);

        unsafe {
            assert_eq!(*result, 11.0);
            assert_eq!(*result.add(1), 22.0);
            assert_eq!(*result.add(2), 33.0);
            assert_eq!(*result.add(3), 44.0);
        }
        jit_simd_free(result, 4);
    }

    #[test]
    fn test_simd_mul_large() {
        let len = 1000;
        let a: Vec<f64> = (0..len).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..len).map(|i| (i * 2) as f64).collect();
        let result = jit_simd_mul(a.as_ptr(), b.as_ptr(), len as u64);

        unsafe {
            for i in 0..len {
                assert_eq!(*result.add(i), (i * i * 2) as f64);
            }
        }
        jit_simd_free(result, len as u64);
    }

    #[test]
    fn test_simd_gt() {
        let a = vec![5.0, 2.0, 8.0, 1.0];
        let b = vec![3.0, 4.0, 8.0, 0.0];
        let result = jit_simd_gt(a.as_ptr(), b.as_ptr(), 4);

        unsafe {
            assert_eq!(*result, 1.0);
            assert_eq!(*result.add(1), 0.0);
            assert_eq!(*result.add(2), 0.0);
            assert_eq!(*result.add(3), 1.0);
        }
        jit_simd_free(result, 4);
    }
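
    #[test]
    fn test_simd_add_scalar() {
        // Exercises the scalar-broadcast path.
        let a = vec![1.0, 2.0, 3.0];
        let result = jit_simd_add_scalar(a.as_ptr(), 0.5, 3);

        unsafe {
            assert_eq!(*result, 1.5);
            assert_eq!(*result.add(1), 2.5);
            assert_eq!(*result.add(2), 3.5);
        }
        jit_simd_free(result, 3);
    }

    #[test]
    fn test_null_and_empty_inputs() {
        // The guard clauses should return null without touching memory.
        assert!(jit_simd_add(std::ptr::null(), std::ptr::null(), 4).is_null());
        let a = vec![1.0];
        assert!(jit_simd_add(a.as_ptr(), a.as_ptr(), 0).is_null());
    }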
}