Skip to main content

simdeez/
invoking.rs

1use crate::{engines, Simd};
2
3#[cfg(target_arch = "aarch64")]
4use std::arch::is_aarch64_feature_detected;
5
/// Turns a parenthesized, comma-separated list of types into a tuple type.
///
/// A trailing comma is emitted after every element, so a single type `(T)` becomes the
/// one-element tuple `(T,)` instead of a merely parenthesized `T`. An empty list stays `()`,
/// and two or more types form the usual tuple.
#[macro_export]
macro_rules! fix_tuple_type {
    (($($elem:ty),*)) => {
        ($($elem,)*)
    };
}
18
/// Shared expansion behind the public `simd_*` generator macros.
///
/// For an input `fn foo(args..) -> R { body }` it emits:
/// - `__foo_dispatch_struct`, a unit struct implementing `__SimdRunner` for this function;
/// - `__foo_generic::<S>`, an `unsafe fn` that takes the arguments as one tuple and runs the
///   body through `S::invoke`;
/// - `foo_generic::<S>`, a safe wrapper that packs the arguments and goes through
///   `__run_simd_generic`.
///
/// The user-written body refers to the engine as `S`, so that type-parameter name must not change.
#[macro_export]
macro_rules! __simd_generate_base {
    (
        $(#[$attr:meta])*
        $vis:vis fn $name:ident $(<$($life:lifetime),+>)? ($($param:ident : $ty:ty),*) -> $ret:ty
        $code:block
    ) => {
        simdeez_paste_item! {
            // Arguments travel through the generic dispatch helpers as a single tuple value,
            // which is why every signature below is rebuilt in tuple form.

            // Marker type that names this function for the `__run_simd_*` helpers.
            #[allow(non_camel_case_types)]
            struct [<__ $name _dispatch_struct>];

            impl $(<$($life),+>)? __SimdRunner<fix_tuple_type!(($($ty),*)), $ret>
                for [<__ $name _dispatch_struct>]
            {
                unsafe fn run<S: Simd>(args_tuple: fix_tuple_type!(($($ty),*))) -> $ret {
                    [<__ $name _generic>]::<S>(args_tuple)
                }
            }

            // Holds the user's body: unpacks the tuple back into the named arguments and
            // evaluates the body inside `S::invoke`.
            #[inline(always)]
            $vis unsafe fn [<__ $name _generic>]<$($($life,)+)? S: 'static + Simd>(
                args_tuple: ($($ty,)*),
            ) -> $ret {
                let ($($param,)*) = args_tuple;
                S::invoke(#[inline(always)] || $code)
            }

            // Safe entry point for callers that name the engine `S` themselves.
            $(#[$attr])*
            #[inline(always)]
            $vis fn [<$name _generic>]<$($($life),+,)? S: Simd>($($param: $ty,)*) -> $ret {
                let args_tuple = ($($param,)*);
                __run_simd_generic::<S, [<__ $name _dispatch_struct>], fix_tuple_type!(($($ty),*)), $ret>(args_tuple)
            }
        }
    };
}
50
/// Generates a safe function that chooses the SIMD engine at runtime.
///
/// For an input `fn foo(args..) -> R { body }` this expands to:
/// - `foo`, which forwards its arguments to `__run_simd_runtime_decide`;
/// - `foo_scalar`, which forwards to `__run_simd_invoke_scalar`;
/// - everything produced by `__simd_generate_base!` for the same function.
///
/// Attributes, visibility and lifetime generics on the input are copied onto the generated
/// functions. The return type may be omitted, in which case `()` is used.
#[macro_export]
macro_rules! simd_runtime_generate {
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) -> $rt:ty $body:block  ) => {
        simdeez_paste_item! {
            // Arguments are handed to the generic dispatch helpers as one tuple, so each
            // wrapper packs its parameters before delegating.

            $(#[$meta])*
            #[inline(always)]
            $vis fn $fn_name $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_runtime_decide::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            $vis fn [<$fn_name _scalar>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_scalar::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            __simd_generate_base!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),* ) -> $rt $body);
        }
    };
    // No explicit return type: re-enter the main arm with `-> ()`. Lifetime generics are
    // accepted and forwarded here too, so `fn f<'a>(x: &'a mut [f32]) { .. }` is supported.
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) $body:block  ) => {
        simd_runtime_generate!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),*) -> () $body);
    };
}
79
/// Generates a safe function whose SIMD engine is fixed at compile time.
///
/// For an input `fn foo(args..) -> R { body }` this expands to:
/// - `foo`, which forwards its arguments to `__run_simd_compiletime_select`;
/// - `foo_scalar`, which forwards to `__run_simd_invoke_scalar`;
/// - everything produced by `__simd_generate_base!` for the same function.
///
/// Attributes, visibility and lifetime generics on the input are copied onto the generated
/// functions. The return type may be omitted, in which case `()` is used.
#[macro_export]
macro_rules! simd_compiletime_select {
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) -> $rt:ty $body:block  ) => {
        simdeez_paste_item! {
            $(#[$meta])*
            #[inline(always)]
            $vis fn $fn_name $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                // The dispatch helpers take all arguments as a single tuple.
                let args_tuple = ($($arg,)*);
                __run_simd_compiletime_select::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            $vis fn [<$fn_name _scalar>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_scalar::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            __simd_generate_base!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),* ) -> $rt $body);
        }
    };
    // No explicit return type: re-enter the main arm with `-> ()`. Lifetime generics are
    // accepted and forwarded here too, so `fn f<'a>(x: &'a mut [f32]) { .. }` is supported.
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) $body:block  ) => {
        simd_compiletime_select!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),*) -> () $body);
    };
}
105
/// Generates a runtime-dispatching function plus one entry point per SIMD engine.
///
/// For an input `fn foo(args..) -> R { body }` this expands to:
/// - `foo`, which forwards its arguments to `__run_simd_runtime_decide`;
/// - `foo_scalar` (safe, emitted on every target), forwarding to `__run_simd_invoke_scalar`;
/// - `foo_sse2`, `foo_sse41`, `foo_avx2`, `foo_avx512` (`unsafe`, x86/x86_64 only);
/// - `foo_neon` (`unsafe`, aarch64 only);
/// - `foo_wasm` (`unsafe`, wasm32 only);
/// - everything produced by `__simd_generate_base!` for the same function.
///
/// The engine-specific functions do no feature detection themselves; they are `unsafe`,
/// presumably because the caller must guarantee the instruction set is available.
///
/// Attributes, visibility and lifetime generics on the input are copied onto the generated
/// functions. The return type may be omitted, in which case `()` is used.
#[macro_export]
macro_rules! simd_unsafe_generate_all {
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) -> $rt:ty $body:block  ) => {
        simdeez_paste_item! {
            $(#[$meta])*
            #[inline(always)]
            $vis fn $fn_name $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                // The dispatch helpers take all arguments as a single tuple.
                let args_tuple = ($($arg,)*);
                __run_simd_runtime_decide::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            // The scalar entry point is not architecture-gated: `__run_simd_invoke_scalar`
            // exists on every target, matching what the sibling generator macros emit.
            $(#[$meta])*
            #[inline(always)]
            $vis fn [<$fn_name _scalar>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_scalar::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
            $vis unsafe fn [<$fn_name _sse2>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_sse2::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
            $vis unsafe fn [<$fn_name _sse41>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_sse41::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
            $vis unsafe fn [<$fn_name _avx2>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_avx2::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
            $vis unsafe fn [<$fn_name _avx512>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_avx512::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(target_arch = "aarch64")]
            $vis unsafe fn [<$fn_name _neon>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_neon::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            $(#[$meta])*
            #[inline(always)]
            #[cfg(target_arch = "wasm32")]
            $vis unsafe fn [<$fn_name _wasm>] $(<$($lt),+>)?($($arg:$typ,)*) -> $rt {
                let args_tuple = ($($arg,)*);
                __run_simd_invoke_wasm::<[<__ $fn_name _dispatch_struct>], fix_tuple_type!(($($typ),*)), $rt>(args_tuple)
            }

            __simd_generate_base!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),* ) -> $rt $body);
        }
    };
    // No explicit return type: re-enter the main arm with `-> ()`. Lifetime generics are
    // accepted and forwarded here too, so `fn f<'a>(x: &'a mut [f32]) { .. }` is supported.
    ($(#[$meta:meta])* $vis:vis fn $fn_name:ident $(<$($lt:lifetime),+>)? ($($arg:ident:$typ:ty),* $(,)? ) $body:block  ) => {
        simd_unsafe_generate_all!($(#[$meta])* $vis fn $fn_name $(<$($lt),+>)? ($($arg:$typ),*) -> () $body);
    };
}
180
181pub trait __SimdRunner<A, R> {
182    unsafe fn run<S: Simd>(args: A) -> R;
183}
184
185#[inline(always)]
186pub fn __run_simd_runtime_decide<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
187    #![allow(unreachable_code)]
188
189    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
190    {
191        if is_x86_feature_detected!("avx512f")
192            && is_x86_feature_detected!("avx512bw")
193            && is_x86_feature_detected!("avx512dq")
194        {
195            return unsafe { S::run::<engines::avx512::Avx512>(args) };
196        }
197
198        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
199            return unsafe { S::run::<engines::avx2::Avx2>(args) };
200        }
201
202        if is_x86_feature_detected!("sse4.1") {
203            return unsafe { S::run::<engines::sse41::Sse41>(args) };
204        }
205
206        if is_x86_feature_detected!("sse2") {
207            return unsafe { S::run::<engines::sse2::Sse2>(args) };
208        }
209    }
210
211    #[cfg(target_arch = "aarch64")]
212    if is_aarch64_feature_detected!("neon") {
213        return unsafe { S::run::<engines::neon::Neon>(args) };
214    }
215
216    #[cfg(target_arch = "wasm32")]
217    {
218        // Note: there's currently no way to detect SIMD support in WebAssembly at runtime
219        return unsafe { S::run::<engines::wasm32::Wasm>(args) };
220    }
221
222    unsafe { S::run::<engines::scalar::Scalar>(args) }
223}
224
225#[inline(always)]
226pub fn __run_simd_generic<E: Simd, S: __SimdRunner<A, R>, A, R>(args: A) -> R {
227    unsafe { S::run::<E>(args) }
228}
229
230#[inline(always)]
231pub fn __run_simd_compiletime_select<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
232    #![allow(unreachable_code)]
233    #![allow(clippy::needless_return)]
234
235    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
236    {
237        #[cfg(all(
238            target_feature = "avx512f",
239            target_feature = "avx512bw",
240            target_feature = "avx512dq"
241        ))]
242        return unsafe { S::run::<engines::avx512::Avx512>(args) };
243
244        #[cfg(all(target_feature = "avx2", target_feature = "fma"))]
245        return unsafe { S::run::<engines::avx2::Avx2>(args) };
246
247        #[cfg(target_feature = "sse4.1")]
248        return unsafe { S::run::<engines::sse41::Sse41>(args) };
249
250        #[cfg(target_feature = "sse2")]
251        return unsafe { S::run::<engines::sse2::Sse2>(args) };
252    }
253
254    #[cfg(target_arch = "aarch64")]
255    {
256        #[cfg(target_feature = "neon")]
257        return unsafe { S::run::<engines::neon::Neon>(args) };
258    }
259
260    #[cfg(target_arch = "wasm32")]
261    {
262        return unsafe { S::run::<engines::wasm32::Wasm>(args) };
263    }
264
265    return unsafe { S::run::<engines::scalar::Scalar>(args) };
266}
267
268#[inline(always)]
269pub fn __run_simd_invoke_scalar<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
270    unsafe { S::run::<engines::scalar::Scalar>(args) }
271}
272
273#[inline(always)]
274#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
275pub unsafe fn __run_simd_invoke_sse2<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
276    unsafe { S::run::<engines::sse2::Sse2>(args) }
277}
278
279#[inline(always)]
280#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
281pub unsafe fn __run_simd_invoke_sse41<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
282    unsafe { S::run::<engines::sse41::Sse41>(args) }
283}
284
285#[inline(always)]
286#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
287pub unsafe fn __run_simd_invoke_avx2<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
288    unsafe { S::run::<engines::avx2::Avx2>(args) }
289}
290
291#[inline(always)]
292#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
293pub unsafe fn __run_simd_invoke_avx512<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
294    unsafe { S::run::<engines::avx512::Avx512>(args) }
295}
296
/// Runs `S` with the NEON engine, without checking that the CPU supports it.
///
/// Gated on `target_arch = "aarch64"`, the same condition used by the macro-generated `_neon`
/// entry points and by the runtime dispatcher in this file, so the helper exists exactly where
/// its callers do.
///
/// # Safety
///
/// No feature detection is performed here; presumably the caller must guarantee NEON is
/// available on the executing CPU.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub unsafe fn __run_simd_invoke_neon<S: __SimdRunner<A, R>, A, R>(args: A) -> R {
    unsafe { S::run::<engines::neon::Neon>(args) }
}
302
/// Runs `S` with the WebAssembly engine; only compiled for wasm32 targets.
///
/// # Safety
///
/// No capability check is performed here; presumably the caller must guarantee the host
/// supports the SIMD instructions the engine uses.
#[cfg(target_arch = "wasm32")]
#[inline(always)]
pub unsafe fn __run_simd_invoke_wasm<S, A, R>(args: A) -> R
where
    S: __SimdRunner<A, R>,
{
    unsafe { S::run::<engines::wasm32::Wasm>(args) }
}
308
/// Evaluates an expression through `invoke` on the named engine type.
///
/// `simd_invoke!(S, expr)` expands to `S::invoke(|| expr)`, with the closure marked
/// `#[inline(always)]`. Everything after the first comma is taken as the closure body.
#[macro_export]
macro_rules! simd_invoke {
    ($engine:ident, $($body_tokens:tt)+) => {
        $engine::invoke(
            #[inline(always)]
            || $($body_tokens)+
        )
    };
}
317}