sklears_simd/
optimization_hints.rs1#[cfg(feature = "no-std")]
7use core::{hint, mem::size_of};
8#[cfg(not(feature = "no-std"))]
9use std::{hint, mem::size_of};
10
/// Namespace for lightweight optimization hints used by SIMD kernels.
///
/// Every item is an associated function, so the type is never instantiated.
/// All functions are safe to call with any argument: a hint whose precondition
/// is checkable panics when it is violated instead of invoking undefined
/// behavior.
pub struct OptimizationHints;

impl OptimizationHints {
    /// Empty `#[cold]` function. Calling it on one arm of a branch tells the
    /// optimizer that this arm is the rarely-taken one.
    #[cold]
    #[inline]
    fn cold_path() {}

    /// Reports a violated hint precondition.
    ///
    /// Kept cold and out of line so that the `#[inline(always)]` callers only
    /// pay for a compare-and-branch on their hot path.
    #[cold]
    #[inline(never)]
    fn hint_violated(what: &str) -> ! {
        panic!("optimization hint violated: {}", what)
    }

    /// Returns `b`, hinting that it is expected to be `true`.
    ///
    /// The `false` arm is routed through a `#[cold]` function, which is the
    /// stable-Rust way to bias branch layout. No instruction with side effects
    /// (such as a prefetch) is issued.
    #[inline(always)]
    pub fn likely(b: bool) -> bool {
        if !b {
            Self::cold_path();
        }
        b
    }

    /// Returns `b`, hinting that it is expected to be `false`.
    #[inline(always)]
    pub fn unlikely(b: bool) -> bool {
        if b {
            Self::cold_path();
        }
        b
    }

    /// Marks `ptr` as intended to be `align`-byte aligned and returns it
    /// unchanged.
    ///
    /// This is a pure pass-through: a safe function cannot promise alignment to
    /// the optimizer on the caller's behalf, so `align` is accepted only to
    /// document intent at the call site. The pointer is never dereferenced,
    /// offset, or checked, so null, dangling and misaligned pointers are fine.
    #[inline(always)]
    pub fn assume_aligned<T>(ptr: *const T, align: usize) -> *const T {
        let _ = align;
        ptr
    }

    /// Mutable counterpart of [`OptimizationHints::assume_aligned`]; returns
    /// `ptr` unchanged.
    #[inline(always)]
    pub fn assume_aligned_mut<T>(ptr: *mut T, align: usize) -> *mut T {
        let _ = align;
        ptr
    }

    /// Returns `value` after establishing that `min <= value <= max`.
    ///
    /// Code following the call may be optimized with the range known.
    ///
    /// # Panics
    ///
    /// Panics if `value` is outside `[min, max]`, including when a comparison
    /// is undefined (for example a floating-point NaN).
    #[inline(always)]
    pub fn assume_range<T: PartialOrd + Copy>(value: T, min: T, max: T) -> T {
        if value >= min && value <= max {
            value
        } else {
            Self::hint_violated("assume_range: value is outside [min, max]")
        }
    }

    /// Returns `slice` after establishing that it holds exactly `len` elements.
    ///
    /// # Panics
    ///
    /// Panics if `slice.len() != len`.
    #[inline(always)]
    pub fn assume_len<T>(slice: &[T], len: usize) -> &[T] {
        if slice.len() == len {
            slice
        } else {
            Self::hint_violated("assume_len: slice length differs from the promised length")
        }
    }

    /// Mutable counterpart of [`OptimizationHints::assume_len`].
    ///
    /// # Panics
    ///
    /// Panics if `slice.len() != len`.
    #[inline(always)]
    pub fn assume_len_mut<T>(slice: &mut [T], len: usize) -> &mut [T] {
        if slice.len() == len {
            slice
        } else {
            Self::hint_violated("assume_len_mut: slice length differs from the promised length")
        }
    }

    /// Returns `count` unchanged; a call-site marker for a loop trip count.
    #[inline(always)]
    pub fn assume_loop_count(count: usize) -> usize {
        count
    }

    /// Requests that the cache line containing `_ptr` be fetched into all cache
    /// levels (`_MM_HINT_T0`).
    ///
    /// Does nothing on targets other than `x86_64`.
    #[inline(always)]
    pub fn prefetch_read<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
            // SAFETY: a prefetch is only a hint; it performs no architectural
            // memory access and does not fault on invalid addresses. SSE is part
            // of the x86_64 baseline.
            unsafe { _mm_prefetch::<_MM_HINT_T0>(_ptr.cast::<i8>()) };
        }
    }

    /// Requests a prefetch of the cache line containing `_ptr` ahead of a
    /// write, using the `_MM_HINT_T2` locality hint.
    ///
    /// NOTE(review): `_MM_HINT_T2` is not a write-intent hint (`_MM_HINT_ET0`
    /// would be); the hint value is kept as it was.
    ///
    /// Does nothing on targets other than `x86_64`.
    #[inline(always)]
    pub fn prefetch_write<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T2};
            // SAFETY: see `prefetch_read`; prefetching never faults.
            unsafe { _mm_prefetch::<_MM_HINT_T2>(_ptr.cast::<i8>()) };
        }
    }

    /// Requests a non-temporal prefetch (`_MM_HINT_NTA`) of the cache line
    /// containing `_ptr`, for data that will not be reused soon.
    ///
    /// Does nothing on targets other than `x86_64`.
    #[inline(always)]
    pub fn prefetch_nta<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_NTA};
            // SAFETY: see `prefetch_read`; prefetching never faults.
            unsafe { _mm_prefetch::<_MM_HINT_NTA>(_ptr.cast::<i8>()) };
        }
    }

    /// Returns `true` when the two `len`-element regions starting at `ptr1` and
    /// `ptr2` do not overlap.
    ///
    /// Only the addresses are compared; nothing is dereferenced. Empty regions
    /// (`len == 0` or a zero-sized `T`) never overlap anything.
    #[inline(always)]
    pub fn assume_noalias<T>(ptr1: *const T, ptr2: *const T, len: usize) -> bool {
        // Both regions have the same byte length, so they are disjoint exactly
        // when their start addresses are at least that far apart. Working with
        // the distance avoids computing `start + length`, which can overflow.
        let span_bytes = len.saturating_mul(size_of::<T>());
        (ptr1 as usize).abs_diff(ptr2 as usize) >= span_bytes
    }

    /// Number of `T` lanes that fit in a 64-byte vector, by element size.
    ///
    /// Element sizes other than 1, 2, 4 or 8 bytes fall back to 4.
    #[inline(always)]
    pub fn optimal_simd_width<T>() -> usize {
        match size_of::<T>() {
            1 => 64,
            2 => 32,
            4 => 16,
            8 => 8,
            _ => 4,
        }
    }
}
159
/// Call-site shorthand for the hints on
/// `optimization_hints::OptimizationHints`.
///
/// Each form `optimize_for_simd!(name(args...))` expands to
/// `OptimizationHints::name(args...)`. A trailing comma after the last
/// argument is accepted.
///
/// Supported names: `likely`, `unlikely`, `assume_aligned`,
/// `assume_aligned_mut`, `assume_len`, `assume_len_mut`, `prefetch_read`,
/// `prefetch_write`, `prefetch_nta`.
#[macro_export]
macro_rules! optimize_for_simd {
    (likely($expr:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::likely($expr)
    };
    (unlikely($expr:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::unlikely($expr)
    };
    (assume_aligned($ptr:expr, $align:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::assume_aligned($ptr, $align)
    };
    (assume_aligned_mut($ptr:expr, $align:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::assume_aligned_mut($ptr, $align)
    };
    (assume_len($slice:expr, $len:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::assume_len($slice, $len)
    };
    (assume_len_mut($slice:expr, $len:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::assume_len_mut($slice, $len)
    };
    (prefetch_read($ptr:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_read($ptr)
    };
    (prefetch_write($ptr:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_write($ptr)
    };
    (prefetch_nta($ptr:expr $(,)?)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_nta($ptr)
    };
}
182
/// Textual spellings of attribute names, exposed as string constants.
///
/// These are plain `&str` values; nothing in this module applies an attribute.
pub mod attributes {
    // ---- Inlining ---------------------------------------------------------

    /// Spelling of the attribute that requests unconditional inlining.
    pub const FORCE_INLINE: &str = "inline(always)";

    /// Spelling of the attribute that forbids inlining.
    pub const NEVER_INLINE: &str = "inline(never)";

    // ---- Code placement and code generation -------------------------------

    /// Spelling of the attribute marking a function as rarely executed.
    pub const COLD: &str = "cold";

    /// Label for frequently executed code.
    ///
    /// NOTE(review): stable Rust does not appear to have a built-in `hot`
    /// attribute; confirm how consumers use this value.
    pub const HOT: &str = "hot";

    /// Name of the attribute that enables CPU features for one function.
    pub const TARGET_FEATURE: &str = "target_feature";

    /// Spelling of the attribute that disables symbol name mangling.
    pub const NO_MANGLE: &str = "no_mangle";

    // ---- Data layout ------------------------------------------------------

    /// Spelling of the C-compatible layout attribute.
    pub const REPR_C: &str = "repr(C)";

    /// Base spelling of the alignment layout attribute, without its argument.
    pub const REPR_ALIGN: &str = "repr(align)";
}
209
210pub mod simd_hints {
212 use super::OptimizationHints;
213
214 #[inline(always)]
216 pub fn assume_simd_aligned<T>(slice: &[T]) -> &[T] {
217 let align = if cfg!(target_feature = "avx512f") {
218 64
219 } else if cfg!(target_feature = "avx2") {
220 32
221 } else {
222 16
223 };
224
225 let ptr = OptimizationHints::assume_aligned(slice.as_ptr(), align);
226 unsafe { core::slice::from_raw_parts(ptr, slice.len()) }
227 }
228
229 #[inline(always)]
231 pub fn assume_simd_aligned_mut<T>(slice: &mut [T]) -> &mut [T] {
232 let align = if cfg!(target_feature = "avx512f") {
233 64
234 } else if cfg!(target_feature = "avx2") {
235 32
236 } else {
237 16
238 };
239
240 let ptr = OptimizationHints::assume_aligned_mut(slice.as_mut_ptr(), align);
241 unsafe { core::slice::from_raw_parts_mut(ptr, slice.len()) }
242 }
243
244 #[inline(always)]
246 pub fn assume_vectorizable<T, F>(slice: &[T], mut f: F)
247 where
248 F: FnMut(&T),
249 {
250 let len = OptimizationHints::assume_loop_count(slice.len());
251 for item in slice.iter().take(len) {
252 f(item);
253 }
254 }
255
256 #[inline(always)]
258 pub fn assume_parallel_beneficial(size: usize) -> bool {
259 OptimizationHints::likely(size > 1000)
260 }
261
262 #[inline(always)]
264 pub fn optimal_chunk_size<T>() -> usize {
265 OptimizationHints::optimal_simd_width::<T>() * 4
266 }
267}
268
/// Unit tests for the hint helpers; skipped for `no-std` builds because they
/// use `vec!`.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_optimization_hints() {
        // Bind the array so the pointer is not taken from a temporary.
        let values = [1.0f32; 16];
        let ptr = values.as_ptr();
        let aligned_ptr = OptimizationHints::assume_aligned(ptr, 16);
        assert_eq!(ptr, aligned_ptr);

        let slice = &[1, 2, 3, 4];
        let len_slice = OptimizationHints::assume_len(slice, 4);
        assert_eq!(slice.len(), len_slice.len());

        let optimal_width = OptimizationHints::optimal_simd_width::<f32>();
        assert!(optimal_width > 0);
    }

    #[test]
    fn test_mut_and_value_hints() {
        let mut values = [1u32, 2, 3, 4];
        let raw = values.as_mut_ptr();
        assert_eq!(OptimizationHints::assume_aligned_mut(raw, 16), raw);

        OptimizationHints::assume_len_mut(&mut values, 4)[0] = 10;
        assert_eq!(values, [10, 2, 3, 4]);

        assert_eq!(OptimizationHints::assume_range(5, 0, 10), 5);
        assert_eq!(OptimizationHints::assume_loop_count(0), 0);
        assert_eq!(OptimizationHints::assume_loop_count(7), 7);
    }

    #[test]
    fn test_noalias() {
        let values = [0u32; 8];
        let base = values.as_ptr();

        // Two adjacent four-element halves do not overlap.
        assert!(OptimizationHints::assume_noalias(base, base.wrapping_add(4), 4));
        // Shifting by two elements makes the regions overlap.
        assert!(!OptimizationHints::assume_noalias(base, base.wrapping_add(2), 4));
        // Empty regions never overlap.
        assert!(OptimizationHints::assume_noalias(base, base, 0));
    }

    #[test]
    fn test_simd_hints() {
        let data = vec![1.0f32; 64];
        let aligned_slice = simd_hints::assume_simd_aligned(&data);
        assert_eq!(data.len(), aligned_slice.len());

        let chunk_size = simd_hints::optimal_chunk_size::<f32>();
        assert!(chunk_size > 0);

        let parallel = simd_hints::assume_parallel_beneficial(2000);
        assert!(parallel);
    }

    #[test]
    fn test_branch_hints() {
        let likely_true = OptimizationHints::likely(true);
        let unlikely_false = OptimizationHints::unlikely(false);

        assert!(likely_true);
        assert!(!unlikely_false);

        // The hints must return their argument for the opposite values too.
        assert!(!OptimizationHints::likely(false));
        assert!(OptimizationHints::unlikely(true));
    }

    #[test]
    fn test_prefetch_hints() {
        // Prefetches have no observable result; this only checks that the
        // calls complete.
        let data = vec![1.0f32; 100];
        OptimizationHints::prefetch_read(data.as_ptr());
        OptimizationHints::prefetch_write(data.as_ptr());
        OptimizationHints::prefetch_nta(data.as_ptr());
    }

    #[test]
    fn test_macro_hints() {
        let data = vec![1.0f32; 16];
        let ptr = optimize_for_simd!(assume_aligned(data.as_ptr(), 16));

        optimize_for_simd!(prefetch_read(ptr));

        let slice = optimize_for_simd!(assume_len(data.as_slice(), 16));
        assert_eq!(slice.len(), 16);
    }
}