Skip to main content

shape_jit/ffi/
simd.rs

// Heap allocation audit (PR-9 V8 Gap Closure):
//   Category A (NaN-boxed returns): 0 sites
//     (SIMD ops return raw *mut f64 pointers, not NaN-boxed values)
//   Category B (intermediate/consumed): 0 sites
//   Category C (heap islands): 16 sites
//     alloc_f64_buffer() allocations (made through simd_binary_op / simd_scalar_op /
//     simd_cmp_op) in: jit_simd_add, jit_simd_sub, jit_simd_mul,
//     jit_simd_div, jit_simd_max, jit_simd_min, jit_simd_add_scalar,
//     jit_simd_sub_scalar, jit_simd_mul_scalar, jit_simd_div_scalar,
//     jit_simd_gt, jit_simd_lt, jit_simd_gte, jit_simd_lte, jit_simd_eq, jit_simd_neq.
//     These raw buffers are returned as *mut f64 and must be freed via
//     jit_simd_free(ptr, len), passing the same len used for the allocation.
//     The JIT compiler is responsible for pairing each allocation with a free call.
//     If the JIT fails to emit the free, this is a memory leak (not a GC island per se,
//     since these are raw allocations outside the NaN-boxing system).
//     TODO: when the GC feature is enabled, route these allocations through
//     gc_allocator (not implemented in this file yet).
15//!
16//! Raw Pointer SIMD Operations for JIT
17//!
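//! These functions operate directly on raw f64 data buffers with no NaN-boxing overhead.
//! Vectorization is left to the compiler's autovectorizer; no explicit SIMD intrinsics are used.
//!
//! Binary and comparison ops: simd_op(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64
//! Scalar-broadcast ops:      simd_op(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64
//!
//! Each op returns a newly allocated buffer of `len` f64 values, or null on a null input,
//! zero length, or allocation failure. Non-null results must be released with
//! `jit_simd_free(ptr, len)`.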
20//!
21//! The JIT compiler extracts Series data pointers and lengths, then calls these
22//! functions directly for maximum performance.
23
24use std::alloc::{Layout, alloc};
25
/// Minimum element count at which the `simd_*_op` helpers switch to their
/// 4-way unrolled loop.
///
/// Shorter inputs run a plain element-by-element loop instead. Both paths compute
/// the same per-element results; only the loop shape differs.
const SIMD_THRESHOLD: usize = 16;
28
29// ============================================================================
30// Binary Operations (Series + Series)
31// ============================================================================
32
33/// SIMD-accelerated vector addition: result[i] = a[i] + b[i]
34/// Returns a newly allocated buffer that must be freed by the caller
35#[unsafe(no_mangle)]
36pub extern "C" fn jit_simd_add(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
37    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a + b)
38}
39
40/// SIMD-accelerated vector subtraction: result[i] = a[i] - b[i]
41#[unsafe(no_mangle)]
42pub extern "C" fn jit_simd_sub(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
43    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a - b)
44}
45
46/// SIMD-accelerated vector multiplication: result[i] = a[i] * b[i]
47#[unsafe(no_mangle)]
48pub extern "C" fn jit_simd_mul(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
49    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a * b)
50}
51
52/// SIMD-accelerated vector division: result[i] = a[i] / b[i]
53#[unsafe(no_mangle)]
54pub extern "C" fn jit_simd_div(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
55    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a / b)
56}
57
58/// SIMD-accelerated element-wise max: result[i] = max(a[i], b[i])
59#[unsafe(no_mangle)]
60pub extern "C" fn jit_simd_max(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
61    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a.max(b))
62}
63
64/// SIMD-accelerated element-wise min: result[i] = min(a[i], b[i])
65#[unsafe(no_mangle)]
66pub extern "C" fn jit_simd_min(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
67    simd_binary_op(a_ptr, b_ptr, len as usize, |a, b| a.min(b))
68}
69
70// ============================================================================
71// Scalar Broadcast Operations (Series + scalar)
72// ============================================================================
73
74/// SIMD-accelerated scalar addition: result[i] = a[i] + scalar
75#[unsafe(no_mangle)]
76pub extern "C" fn jit_simd_add_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
77    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a + s)
78}
79
80/// SIMD-accelerated scalar subtraction: result[i] = a[i] - scalar
81#[unsafe(no_mangle)]
82pub extern "C" fn jit_simd_sub_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
83    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a - s)
84}
85
86/// SIMD-accelerated scalar multiplication: result[i] = a[i] * scalar
87#[unsafe(no_mangle)]
88pub extern "C" fn jit_simd_mul_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
89    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a * s)
90}
91
92/// SIMD-accelerated scalar division: result[i] = a[i] / scalar
93#[unsafe(no_mangle)]
94pub extern "C" fn jit_simd_div_scalar(a_ptr: *const f64, scalar: f64, len: u64) -> *mut f64 {
95    simd_scalar_op(a_ptr, scalar, len as usize, |a, s| a / s)
96}
97
98// ============================================================================
99// Comparison Operations (return f64: 1.0 = true, 0.0 = false)
100// ============================================================================
101
102/// SIMD-accelerated greater-than: result[i] = (a[i] > b[i]) ? 1.0 : 0.0
103#[unsafe(no_mangle)]
104pub extern "C" fn jit_simd_gt(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
105    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a > b)
106}
107
108/// SIMD-accelerated less-than: result[i] = (a[i] < b[i]) ? 1.0 : 0.0
109#[unsafe(no_mangle)]
110pub extern "C" fn jit_simd_lt(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
111    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a < b)
112}
113
114/// SIMD-accelerated greater-than-or-equal: result[i] = (a[i] >= b[i]) ? 1.0 : 0.0
115#[unsafe(no_mangle)]
116pub extern "C" fn jit_simd_gte(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
117    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a >= b)
118}
119
120/// SIMD-accelerated less-than-or-equal: result[i] = (a[i] <= b[i]) ? 1.0 : 0.0
121#[unsafe(no_mangle)]
122pub extern "C" fn jit_simd_lte(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
123    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| a <= b)
124}
125
126/// SIMD-accelerated equality: result[i] = (a[i] == b[i]) ? 1.0 : 0.0
127#[unsafe(no_mangle)]
128pub extern "C" fn jit_simd_eq(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
129    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| {
130        (a - b).abs() < f64::EPSILON
131    })
132}
133
134/// SIMD-accelerated inequality: result[i] = (a[i] != b[i]) ? 1.0 : 0.0
135#[unsafe(no_mangle)]
136pub extern "C" fn jit_simd_neq(a_ptr: *const f64, b_ptr: *const f64, len: u64) -> *mut f64 {
137    simd_cmp_op(a_ptr, b_ptr, len as usize, |a, b| {
138        (a - b).abs() >= f64::EPSILON
139    })
140}
141
142// ============================================================================
143// Helper Functions
144// ============================================================================
145
/// Allocate an uninitialized, 32-byte-aligned buffer for `len` f64 values.
///
/// Returns null in three cases:
/// - `len == 0`;
/// - the byte size `len * size_of::<f64>()` would overflow or is too large for a
///   `Layout`;
/// - the allocator reports failure.
///
/// Non-null buffers have layout (size = `len * 8` bytes, align = 32) and must be
/// released with `jit_simd_free(ptr, len)`.
#[inline]
fn alloc_f64_buffer(len: usize) -> *mut f64 {
    if len == 0 {
        // `std::alloc::alloc` must not be called with a zero-sized layout.
        return std::ptr::null_mut();
    }
    // 32-byte alignment for AVX. `Layout::array` performs a checked size
    // computation, so an absurd `len` yields null instead of a wrapped, undersized
    // buffer that callers would then write past.
    let Some(layout) = Layout::array::<f64>(len)
        .ok()
        .and_then(|layout| layout.align_to(32).ok())
    else {
        return std::ptr::null_mut();
    };
    // SAFETY: `layout` has non-zero size because `len > 0`.
    unsafe { alloc(layout).cast::<f64>() }
}
157
158/// Generic binary operation with autovectorization hints
159#[inline]
160fn simd_binary_op<F>(a_ptr: *const f64, b_ptr: *const f64, len: usize, op: F) -> *mut f64
161where
162    F: Fn(f64, f64) -> f64,
163{
164    if a_ptr.is_null() || b_ptr.is_null() || len == 0 {
165        return std::ptr::null_mut();
166    }
167
168    let result = alloc_f64_buffer(len);
169    if result.is_null() {
170        return std::ptr::null_mut();
171    }
172
173    unsafe {
174        let a = std::slice::from_raw_parts(a_ptr, len);
175        let b = std::slice::from_raw_parts(b_ptr, len);
176        let out = std::slice::from_raw_parts_mut(result, len);
177
178        if len >= SIMD_THRESHOLD {
179            // Process 4 elements at a time for autovectorization
180            let chunks = len / 4;
181            for i in 0..chunks {
182                let idx = i * 4;
183                out[idx] = op(a[idx], b[idx]);
184                out[idx + 1] = op(a[idx + 1], b[idx + 1]);
185                out[idx + 2] = op(a[idx + 2], b[idx + 2]);
186                out[idx + 3] = op(a[idx + 3], b[idx + 3]);
187            }
188            // Handle remainder
189            for i in (chunks * 4)..len {
190                out[i] = op(a[i], b[i]);
191            }
192        } else {
193            // Scalar fallback for small arrays
194            for i in 0..len {
195                out[i] = op(a[i], b[i]);
196            }
197        }
198    }
199
200    result
201}
202
203/// Generic scalar broadcast operation with autovectorization hints
204#[inline]
205fn simd_scalar_op<F>(a_ptr: *const f64, scalar: f64, len: usize, op: F) -> *mut f64
206where
207    F: Fn(f64, f64) -> f64,
208{
209    if a_ptr.is_null() || len == 0 {
210        return std::ptr::null_mut();
211    }
212
213    let result = alloc_f64_buffer(len);
214    if result.is_null() {
215        return std::ptr::null_mut();
216    }
217
218    unsafe {
219        let a = std::slice::from_raw_parts(a_ptr, len);
220        let out = std::slice::from_raw_parts_mut(result, len);
221
222        if len >= SIMD_THRESHOLD {
223            // Process 4 elements at a time for autovectorization
224            let chunks = len / 4;
225            for i in 0..chunks {
226                let idx = i * 4;
227                out[idx] = op(a[idx], scalar);
228                out[idx + 1] = op(a[idx + 1], scalar);
229                out[idx + 2] = op(a[idx + 2], scalar);
230                out[idx + 3] = op(a[idx + 3], scalar);
231            }
232            // Handle remainder
233            for i in (chunks * 4)..len {
234                out[i] = op(a[i], scalar);
235            }
236        } else {
237            // Scalar fallback for small arrays
238            for i in 0..len {
239                out[i] = op(a[i], scalar);
240            }
241        }
242    }
243
244    result
245}
246
247/// Generic comparison operation
248#[inline]
249fn simd_cmp_op<F>(a_ptr: *const f64, b_ptr: *const f64, len: usize, op: F) -> *mut f64
250where
251    F: Fn(f64, f64) -> bool,
252{
253    if a_ptr.is_null() || b_ptr.is_null() || len == 0 {
254        return std::ptr::null_mut();
255    }
256
257    let result = alloc_f64_buffer(len);
258    if result.is_null() {
259        return std::ptr::null_mut();
260    }
261
262    unsafe {
263        let a = std::slice::from_raw_parts(a_ptr, len);
264        let b = std::slice::from_raw_parts(b_ptr, len);
265        let out = std::slice::from_raw_parts_mut(result, len);
266
267        if len >= SIMD_THRESHOLD {
268            // Process 4 elements at a time
269            let chunks = len / 4;
270            for i in 0..chunks {
271                let idx = i * 4;
272                out[idx] = if op(a[idx], b[idx]) { 1.0 } else { 0.0 };
273                out[idx + 1] = if op(a[idx + 1], b[idx + 1]) { 1.0 } else { 0.0 };
274                out[idx + 2] = if op(a[idx + 2], b[idx + 2]) { 1.0 } else { 0.0 };
275                out[idx + 3] = if op(a[idx + 3], b[idx + 3]) { 1.0 } else { 0.0 };
276            }
277            // Handle remainder
278            for i in (chunks * 4)..len {
279                out[i] = if op(a[i], b[i]) { 1.0 } else { 0.0 };
280            }
281        } else {
282            // Scalar fallback
283            for i in 0..len {
284                out[i] = if op(a[i], b[i]) { 1.0 } else { 0.0 };
285            }
286        }
287    }
288
289    result
290}
291
/// Free a SIMD result buffer returned by the `jit_simd_*` functions.
///
/// `len` must be the same element count that was passed to the allocating call.
/// The deallocation layout (size `len * 8` bytes, 32-byte alignment) is rebuilt
/// from it.
///
/// The call does nothing in these cases:
/// - `ptr` is null or `len == 0` (those calls never return an allocation);
/// - `len` describes a layout that cannot be represented. Returning avoids
///   computing a wrapped size or panicking across the FFI boundary.
#[unsafe(no_mangle)]
pub extern "C" fn jit_simd_free(ptr: *mut f64, len: u64) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // Checked layout reconstruction; must match the (len * 8, align 32) layout
    // used when the buffer was allocated.
    let Some(layout) = usize::try_from(len)
        .ok()
        .and_then(|n| Layout::array::<f64>(n).ok())
        .and_then(|layout| layout.align_to(32).ok())
    else {
        return;
    };
    // SAFETY: per this function's contract, `ptr` was allocated with exactly this
    // layout and has not been freed yet.
    unsafe {
        std::alloc::dealloc(ptr.cast::<u8>(), layout);
    }
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Copy `len` values out of a `jit_simd_*` result buffer, then release it.
    fn drain_result(buf: *mut f64, len: usize) -> Vec<f64> {
        // SAFETY: `buf` is a live result buffer holding `len` initialized values.
        let values = unsafe { std::slice::from_raw_parts(buf, len) }.to_vec();
        jit_simd_free(buf, len as u64);
        values
    }

    #[test]
    fn test_simd_add() {
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [10.0, 20.0, 30.0, 40.0];
        let sums = drain_result(jit_simd_add(lhs.as_ptr(), rhs.as_ptr(), 4), 4);
        assert_eq!(sums, [11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_simd_mul_large() {
        const LEN: usize = 1000;
        let lhs: Vec<f64> = (0..LEN).map(|i| i as f64).collect();
        let rhs: Vec<f64> = (0..LEN).map(|i| (i * 2) as f64).collect();
        let products = drain_result(jit_simd_mul(lhs.as_ptr(), rhs.as_ptr(), LEN as u64), LEN);
        for (i, &p) in products.iter().enumerate() {
            assert_eq!(p, (i * i * 2) as f64);
        }
    }

    #[test]
    fn test_simd_gt() {
        let lhs = [5.0, 2.0, 8.0, 1.0];
        let rhs = [3.0, 4.0, 8.0, 0.0];
        let flags = drain_result(jit_simd_gt(lhs.as_ptr(), rhs.as_ptr(), 4), 4);
        // 5 > 3 is true, 2 > 4 is false, 8 > 8 is false, 1 > 0 is true
        assert_eq!(flags, [1.0, 0.0, 0.0, 1.0]);
    }
}