Skip to main content

sklears_simd/
optimization_hints.rs

1//! Compile-time optimization hints for SIMD operations
2//!
3//! This module provides compiler hints and attributes to help the compiler
4//! optimize SIMD operations more effectively.
5
6#[cfg(feature = "no-std")]
7use core::{hint, mem::size_of};
8#[cfg(not(feature = "no-std"))]
9use std::{hint, mem::size_of};
10
/// Zero-sized namespace for compile-time and runtime optimization hints used by
/// the SIMD kernels.
///
/// Every method is safe to call with any argument, except
/// [`OptimizationHints::assume_unchecked`]. That method is `unsafe` and
/// documents its contract under `# Safety`.
pub struct OptimizationHints;

impl OptimizationHints {
    /// Empty marker called on the rarely taken side of a branch.
    ///
    /// A call to a `#[cold]` function is the stable-Rust way to ask LLVM to
    /// treat the enclosing branch as unlikely. The function itself does nothing.
    #[inline]
    #[cold]
    fn cold_path() {}

    /// Hints that `b` is usually `true`, and returns `b` unchanged.
    ///
    /// The `false` side calls a `#[cold]` function, so the `true` side can be
    /// laid out as the fall-through path. This is best-effort: the compiler may
    /// ignore it, and no instruction is added to the likely path.
    #[inline(always)]
    pub fn likely(b: bool) -> bool {
        if !b {
            Self::cold_path();
        }
        b
    }

    /// Hints that `b` is usually `false`, and returns `b` unchanged.
    ///
    /// This mirrors [`likely`](Self::likely): the `true` side is the one
    /// marked cold.
    #[inline(always)]
    pub fn unlikely(b: bool) -> bool {
        if b {
            Self::cold_path();
        }
        b
    }

    /// Returns `ptr` unchanged. `align` is the alignment the caller expects.
    ///
    /// Stable Rust has no safe way to promise alignment to the optimizer, so
    /// this is an identity function kept for API compatibility. It never
    /// dereferences `ptr`, so any pointer value is accepted, including null or
    /// dangling pointers.
    #[inline(always)]
    pub fn assume_aligned<T>(ptr: *const T, align: usize) -> *const T {
        // Alignment cannot be conveyed soundly from a safe fn; accepted for API parity.
        let _ = align;
        ptr
    }

    /// Mutable-pointer counterpart of [`assume_aligned`](Self::assume_aligned).
    ///
    /// It likewise returns `ptr` unchanged and never dereferences it.
    #[inline(always)]
    pub fn assume_aligned_mut<T>(ptr: *mut T, align: usize) -> *mut T {
        // Alignment cannot be conveyed soundly from a safe fn; accepted for API parity.
        let _ = align;
        ptr
    }

    /// Checks that `min <= value <= max` and returns `value`.
    ///
    /// Once this returns, the optimizer may rely on the bounds, for example to
    /// drop later range checks.
    ///
    /// # Panics
    ///
    /// Panics if `value` lies outside `[min, max]`, or if the comparison is
    /// unordered (e.g. a NaN float). Use [`assume_unchecked`](Self::assume_unchecked)
    /// when the check itself must be elided.
    #[inline(always)]
    #[track_caller]
    pub fn assume_range<T: PartialOrd + Copy>(value: T, min: T, max: T) -> T {
        assert!(
            value >= min && value <= max,
            "assume_range: value outside [min, max]"
        );
        value
    }

    /// Checks that `slice.len() == len` and returns `slice`.
    ///
    /// After this call the compiler knows the length, which lets it remove
    /// bounds checks against `len`.
    ///
    /// # Panics
    ///
    /// Panics if the slice length differs from `len`.
    #[inline(always)]
    #[track_caller]
    pub fn assume_len<T>(slice: &[T], len: usize) -> &[T] {
        assert_eq!(slice.len(), len, "assume_len: slice length mismatch");
        slice
    }

    /// Mutable counterpart of [`assume_len`](Self::assume_len).
    ///
    /// # Panics
    ///
    /// Panics if the slice length differs from `len`.
    #[inline(always)]
    #[track_caller]
    pub fn assume_len_mut<T>(slice: &mut [T], len: usize) -> &mut [T] {
        assert_eq!(slice.len(), len, "assume_len_mut: slice length mismatch");
        slice
    }

    /// Tells the optimizer that `cond` always holds, without emitting a check.
    ///
    /// This is the unchecked counterpart of [`assume_range`](Self::assume_range)
    /// and [`assume_len`](Self::assume_len).
    ///
    /// # Safety
    ///
    /// `cond` must be `true`. Passing `false` is undefined behavior.
    #[inline(always)]
    pub unsafe fn assume_unchecked(cond: bool) {
        if !cond {
            // SAFETY: the caller guarantees `cond` is true, so this branch is
            // never reached.
            unsafe { hint::unreachable_unchecked() }
        }
    }

    /// Returns `count` unchanged.
    ///
    /// Kept as a named marker for loop trip counts. It carries no additional
    /// information to the optimizer.
    #[inline(always)]
    pub fn assume_loop_count(count: usize) -> usize {
        count
    }

    /// Prefetches the cache line holding `_ptr` into all cache levels
    /// (`_MM_HINT_T0`), ahead of a read. This is a no-op on non-x86_64 targets.
    #[inline(always)]
    pub fn prefetch_read<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
            // SAFETY: PREFETCHh is a pure hint; it does not fault and has no
            // architecturally visible effect, so any address is acceptable.
            unsafe { _mm_prefetch::<{ _MM_HINT_T0 }>(_ptr.cast::<i8>()) };
        }
        // AArch64 prefetch intrinsics are still unstable, so nothing is emitted there.
    }

    /// Prefetches the cache line holding `_ptr` with write intent
    /// (`_MM_HINT_ET0`). This is a no-op on non-x86_64 targets.
    ///
    /// Which instruction this becomes depends on the CPU features enabled for
    /// the build target.
    #[inline(always)]
    pub fn prefetch_write<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_ET0};
            // Strategy 1 would be `_MM_HINT_T2` (a *read* prefetch), not "write".
            // SAFETY: prefetching is a pure hint and never faults.
            unsafe { _mm_prefetch::<{ _MM_HINT_ET0 }>(_ptr.cast::<i8>()) };
        }
        // AArch64 prefetch intrinsics are still unstable, so nothing is emitted there.
    }

    /// Prefetches the cache line holding `_ptr` as non-temporal data
    /// (`_MM_HINT_NTA`), minimizing cache pollution. This is a no-op on
    /// non-x86_64 targets.
    #[inline(always)]
    pub fn prefetch_nta<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_mm_prefetch, _MM_HINT_NTA};
            // SAFETY: prefetching is a pure hint and never faults.
            unsafe { _mm_prefetch::<{ _MM_HINT_NTA }>(_ptr.cast::<i8>()) };
        }
    }

    /// Returns `true` when the two regions of `len` elements starting at `ptr1`
    /// and `ptr2` do not overlap in memory.
    ///
    /// Despite its name, this is a runtime check: callers branch on it to pick
    /// a vectorized path. Empty regions (`len == 0` or a zero-sized `T`) never
    /// overlap. Only addresses are compared; neither pointer is dereferenced.
    #[inline(always)]
    pub fn assume_noalias<T>(ptr1: *const T, ptr2: *const T, len: usize) -> bool {
        // Saturate rather than overflow. An absurd `len` then yields a span
        // larger than any real distance, so the result is (conservatively)
        // "may alias".
        let span = len.saturating_mul(size_of::<T>());
        if span == 0 {
            return true;
        }
        // Both regions have the same byte length. They are therefore disjoint
        // exactly when their start addresses are at least `span` bytes apart.
        (ptr1 as usize).abs_diff(ptr2 as usize) >= span
    }

    /// Returns the number of `T` elements that fill one 64-byte (512-bit) vector.
    ///
    /// The width is fixed at 64 bytes regardless of the compilation target.
    /// Element sizes other than 1, 2, 4 or 8 bytes fall back to 4.
    #[inline(always)]
    pub fn optimal_simd_width<T>() -> usize {
        match size_of::<T>() {
            1 => 64, // u8 / i8
            2 => 32, // u16 / i16
            4 => 16, // u32 / i32 / f32
            8 => 8,  // u64 / i64 / f64
            _ => 4,  // fallback for other sizes (including zero-sized types)
        }
    }
}
159
/// Macro front-end for the `OptimizationHints` methods.
///
/// Each arm forwards its arguments unchanged to the `OptimizationHints` method
/// of the same name. For example, `optimize_for_simd!(likely(x))` expands to
/// `OptimizationHints::likely(x)`.
#[macro_export]
macro_rules! optimize_for_simd {
    (likely($expr:expr)) => {
        $crate::optimization_hints::OptimizationHints::likely($expr)
    };
    (unlikely($expr:expr)) => {
        $crate::optimization_hints::OptimizationHints::unlikely($expr)
    };
    (assume_aligned($ptr:expr, $align:expr)) => {
        $crate::optimization_hints::OptimizationHints::assume_aligned($ptr, $align)
    };
    (assume_aligned_mut($ptr:expr, $align:expr)) => {
        $crate::optimization_hints::OptimizationHints::assume_aligned_mut($ptr, $align)
    };
    (assume_range($value:expr, $min:expr, $max:expr)) => {
        $crate::optimization_hints::OptimizationHints::assume_range($value, $min, $max)
    };
    (assume_len($slice:expr, $len:expr)) => {
        $crate::optimization_hints::OptimizationHints::assume_len($slice, $len)
    };
    (assume_len_mut($slice:expr, $len:expr)) => {
        $crate::optimization_hints::OptimizationHints::assume_len_mut($slice, $len)
    };
    (prefetch_read($ptr:expr)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_read($ptr)
    };
    (prefetch_write($ptr:expr)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_write($ptr)
    };
    (prefetch_nta($ptr:expr)) => {
        $crate::optimization_hints::OptimizationHints::prefetch_nta($ptr)
    };
}
182
/// Spellings of the compiler attributes relevant to SIMD code, as plain strings.
///
/// These are `&str` constants; they name attributes but cannot be applied as
/// attributes themselves.
pub mod attributes {
    // --- Inlining control -------------------------------------------------

    /// `#[inline(always)]`: force inlining of small SIMD helpers.
    pub const FORCE_INLINE: &str = "inline(always)";

    /// `#[inline(never)]`: keep larger functions out of line.
    pub const NEVER_INLINE: &str = "inline(never)";

    // --- Code generation --------------------------------------------------

    /// `#[target_feature(...)]`: enable a CPU feature for a single function.
    pub const TARGET_FEATURE: &str = "target_feature";

    /// `#[cold]`: marks rarely executed code.
    pub const COLD: &str = "cold";

    /// Marks frequently executed code. Stable Rust has no `#[hot]` attribute;
    /// the name mirrors the GCC/Clang `hot` function attribute.
    pub const HOT: &str = "hot";

    // --- Layout and FFI ---------------------------------------------------

    /// `#[no_mangle]`: export an unmangled symbol for C FFI.
    pub const NO_MANGLE: &str = "no_mangle";

    /// `#[repr(C)]`: C-compatible field layout.
    pub const REPR_C: &str = "repr(C)";

    /// `#[repr(align(N))]`: raise a type's alignment, e.g. to a SIMD width.
    /// The string leaves out the `(N)` argument.
    pub const REPR_ALIGN: &str = "repr(align)";
}
209
210/// SIMD-specific compiler hints
211pub mod simd_hints {
212    use super::OptimizationHints;
213
214    /// Hint that arrays are SIMD-aligned
215    #[inline(always)]
216    pub fn assume_simd_aligned<T>(slice: &[T]) -> &[T] {
217        let align = if cfg!(target_feature = "avx512f") {
218            64
219        } else if cfg!(target_feature = "avx2") {
220            32
221        } else {
222            16
223        };
224
225        let ptr = OptimizationHints::assume_aligned(slice.as_ptr(), align);
226        unsafe { core::slice::from_raw_parts(ptr, slice.len()) }
227    }
228
229    /// Hint that arrays are SIMD-aligned (mutable)
230    #[inline(always)]
231    pub fn assume_simd_aligned_mut<T>(slice: &mut [T]) -> &mut [T] {
232        let align = if cfg!(target_feature = "avx512f") {
233            64
234        } else if cfg!(target_feature = "avx2") {
235            32
236        } else {
237            16
238        };
239
240        let ptr = OptimizationHints::assume_aligned_mut(slice.as_mut_ptr(), align);
241        unsafe { core::slice::from_raw_parts_mut(ptr, slice.len()) }
242    }
243
244    /// Hint that loop will vectorize
245    #[inline(always)]
246    pub fn assume_vectorizable<T, F>(slice: &[T], mut f: F)
247    where
248        F: FnMut(&T),
249    {
250        let len = OptimizationHints::assume_loop_count(slice.len());
251        for item in slice.iter().take(len) {
252            f(item);
253        }
254    }
255
256    /// Hint that parallel processing is beneficial
257    #[inline(always)]
258    pub fn assume_parallel_beneficial(size: usize) -> bool {
259        OptimizationHints::likely(size > 1000)
260    }
261
262    /// Hint for optimal chunk size
263    #[inline(always)]
264    pub fn optimal_chunk_size<T>() -> usize {
265        OptimizationHints::optimal_simd_width::<T>() * 4
266    }
267}
268
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Pointer, length and width hints hand back usable values.
    #[test]
    fn test_optimization_hints() {
        let backing = [1.0f32; 16];
        let raw = backing.as_ptr();
        assert_eq!(raw, OptimizationHints::assume_aligned(raw, 16));

        let items: &[i32] = &[1, 2, 3, 4];
        let checked = OptimizationHints::assume_len(items, 4);
        assert_eq!(items.len(), checked.len());

        assert!(OptimizationHints::optimal_simd_width::<f32>() > 0);
    }

    /// Slice-level SIMD helpers preserve length and report sensible sizes.
    #[test]
    fn test_simd_hints() {
        let buffer = vec![1.0f32; 64];
        assert_eq!(buffer.len(), simd_hints::assume_simd_aligned(&buffer).len());
        assert!(simd_hints::optimal_chunk_size::<f32>() > 0);
        assert!(simd_hints::assume_parallel_beneficial(2000));
    }

    /// Branch hints pass their argument through.
    #[test]
    fn test_branch_hints() {
        assert!(OptimizationHints::likely(true));
        assert!(!OptimizationHints::unlikely(false));
    }

    /// Every prefetch flavour can be issued on a live buffer.
    #[test]
    fn test_prefetch_hints() {
        let buffer = vec![1.0f32; 100];
        let head = buffer.as_ptr();
        OptimizationHints::prefetch_read(head);
        OptimizationHints::prefetch_write(head);
        OptimizationHints::prefetch_nta(head);
        // Reaching this point means none of the prefetches crashed.
    }

    /// The `optimize_for_simd!` macro forwards to the matching hints.
    #[test]
    fn test_macro_hints() {
        let buffer = vec![1.0f32; 16];
        let head = optimize_for_simd!(assume_aligned(buffer.as_ptr(), 16));
        optimize_for_simd!(prefetch_read(head));

        let view = optimize_for_simd!(assume_len(buffer.as_slice(), 16));
        assert_eq!(view.len(), 16);
    }
}