scirs2_core/simd_aligned.rs

//! Memory-aligned SIMD operations for optimal performance
//!
//! This module provides SIMD operations that work with properly aligned memory
//! for maximum performance. These operations are designed to be faster than
//! the general SIMD operations when you control the memory layout.
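//!
//! # Example
//!
//! A minimal usage sketch. The public module path is assumed from the file
//! location, so the example is not compiled as a doctest.
//!
//! ```ignore
//! use scirs2_core::simd_aligned::AlignedVec;
//!
//! let mut v = AlignedVec::<f32>::with_capacity(8).unwrap();
//! v.push(1.0);
//! v.push(2.0);
//! assert_eq!(v.as_slice(), &[1.0, 2.0]);
//! ```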

use std::alloc::{alloc, dealloc, Layout};
use std::ptr::{self, NonNull};
use std::slice;

/// Memory alignment for SIMD operations (32 bytes for AVX2)
pub const SIMD_ALIGNMENT: usize = 32;

/// A memory-aligned vector for optimal SIMD performance
pub struct AlignedVec<T> {
    ptr: NonNull<T>,
    len: usize,
    capacity: usize,
}

impl<T> AlignedVec<T> {
    /// Create a new aligned vector with the specified capacity
    pub fn with_capacity(capacity: usize) -> Result<Self, Box<dyn std::error::Error>> {
        if capacity == 0 {
            return Ok(Self {
                ptr: NonNull::dangling(),
                len: 0,
                capacity: 0,
            });
        }

        // Handle Zero-Sized Types (ZST) to avoid undefined behavior
        // When size_of::<T>() == 0, we must not call alloc()
        if std::mem::size_of::<T>() == 0 {
            return Ok(Self {
                ptr: NonNull::dangling(),
                len: 0,
                capacity,
            });
        }

        // Use checked multiplication so an enormous capacity cannot silently wrap
        let size = capacity
            .checked_mul(std::mem::size_of::<T>())
            .ok_or("Capacity overflow")?;
        let layout = Layout::from_size_align(size, SIMD_ALIGNMENT)
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

        let ptr = unsafe { alloc(layout) };
        if ptr.is_null() {
            return Err("Memory allocation failed".into());
        }

        Ok(Self {
            ptr: unsafe { NonNull::new_unchecked(ptr as *mut T) },
            len: 0,
            capacity,
        })
    }

    /// Create a new aligned vector from an existing vector
    pub fn from_vec(vec: Vec<T>) -> Result<Self, Box<dyn std::error::Error>>
    where
        T: Copy,
    {
        let mut aligned = Self::with_capacity(vec.len())?;
        for item in vec {
            aligned.push(item);
        }
        Ok(aligned)
    }

    /// Push an element to the vector
    pub fn push(&mut self, value: T) {
        if self.len >= self.capacity {
            panic!("AlignedVec capacity exceeded");
        }

        unsafe {
            ptr::write(self.ptr.as_ptr().add(self.len), value);
        }
        self.len += 1;
    }

    /// Get the length of the vector
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if the vector is empty
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get the capacity of the vector
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Get a slice of the vector
    pub fn as_slice(&self) -> &[T] {
        if self.len == 0 {
            &[]
        } else {
            unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
        }
    }

    /// Get a mutable slice of the vector
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        if self.len == 0 {
            &mut []
        } else {
            unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
        }
    }

    /// Set an element at the specified index
    ///
    /// # Panics
    /// Panics if `index` is out of bounds
    pub fn set(&mut self, index: usize, value: T) {
        assert!(
            index < self.len,
            "Index {} out of bounds for length {}",
            index,
            self.len
        );
        // `ptr::write` overwrites the slot without dropping the previous value,
        // which also makes `set` usable for initializing elements of a vector
        // created with `with_capacity_uninit`
        unsafe {
            ptr::write(self.ptr.as_ptr().add(index), value);
        }
    }

    /// Get an element at the specified index
    ///
    /// # Panics
    /// Panics if `index` is out of bounds
    pub fn get(&self, index: usize) -> &T {
        assert!(
            index < self.len,
            "Index {} out of bounds for length {}",
            index,
            self.len
        );
        unsafe { &*self.ptr.as_ptr().add(index) }
    }

    /// Create an uninitialized aligned vector with specified capacity and length
    ///
    /// # Safety
    /// The caller must initialize all elements before reading them
    pub unsafe fn with_capacity_uninit(
        capacity: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut vec = Self::with_capacity(capacity)?;
        vec.len = capacity; // Set length without initializing
        Ok(vec)
    }

    /// Fill the vector with copies of a value
    pub fn fill(&mut self, value: T)
    where
        T: Copy,
    {
        for i in 0..self.len {
            unsafe {
                ptr::write(self.ptr.as_ptr().add(i), value);
            }
        }
    }

    /// Clear the vector, dropping all elements
    pub fn clear(&mut self) {
        for i in 0..self.len {
            unsafe {
                ptr::drop_in_place(self.ptr.as_ptr().add(i));
            }
        }
        self.len = 0;
    }

    /// Convert to a regular Vec
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.as_slice().to_vec()
    }

    /// Unsafe method to set the length directly
    ///
    /// # Safety
    /// The caller must ensure that:
    /// - `new_len` is <= capacity
    /// - Elements from index 0 to new_len-1 are properly initialized
    pub unsafe fn set_len(&mut self, new_len: usize) {
        debug_assert!(new_len <= self.capacity);
        self.len = new_len;
    }

    /// Get a mutable pointer to the underlying data
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Get a pointer to the underlying data
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }
}

impl<T> Drop for AlignedVec<T> {
    fn drop(&mut self) {
        unsafe {
            // Drop all initialized elements (sound for ZSTs too, since the
            // dangling pointer is non-null and aligned)
            for i in 0..self.len {
                ptr::drop_in_place(self.ptr.as_ptr().add(i));
            }

            // Deallocate only if memory was actually allocated:
            // `with_capacity` never calls `alloc` for zero capacity or
            // zero-sized types
            if self.capacity != 0 && std::mem::size_of::<T>() != 0 {
                let layout = Layout::from_size_align_unchecked(
                    self.capacity * std::mem::size_of::<T>(),
                    SIMD_ALIGNMENT,
                );
                dealloc(self.ptr.as_ptr() as *mut u8, layout);
            }
        }
    }
}

unsafe impl<T: Send> Send for AlignedVec<T> {}
unsafe impl<T: Sync> Sync for AlignedVec<T> {}

/// High-performance SIMD addition of two f32 slices, returning a 32-byte-aligned result vector
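///
/// # Example
///
/// A minimal sketch of the call pattern (imports omitted, so not compiled as a doctest):
///
/// ```ignore
/// let a = [1.0f32, 2.0, 3.0, 4.0];
/// let b = [4.0f32, 3.0, 2.0, 1.0];
/// let sum = simd_add_aligned_f32(&a, &b).unwrap();
/// assert_eq!(sum.as_slice(), &[5.0f32, 5.0, 5.0, 5.0]);
/// ```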
pub fn simd_add_aligned_f32(a: &[f32], b: &[f32]) -> Result<AlignedVec<f32>, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();
    let mut result: AlignedVec<f32> =
        AlignedVec::with_capacity(len).map_err(|_| "Failed to allocate aligned memory")?;

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                let mut i = 0;

                // Process 8 f32s at a time with AVX2
                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    // Use aligned loads if possible
                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm256_add_ps(a_vec, b_vec);

                    // Store aligned result
                    _mm256_store_ps(result_ptr, result_vec);

                    i += 8;
                }

                // Update length for the SIMD-processed elements
                result.len = i;

                // Handle remaining elements
                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else if is_x86_feature_detected!("sse") {
            unsafe {
                let mut i = 0;

                // Process 4 f32s at a time with SSE
                while i + 4 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    let a_vec = if (a_ptr as usize) % 16 == 0 {
                        _mm_load_ps(a_ptr)
                    } else {
                        _mm_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 16 == 0 {
                        _mm_load_ps(b_ptr)
                    } else {
                        _mm_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm_add_ps(a_vec, b_vec);
                    _mm_store_ps(result_ptr, result_vec);

                    i += 4;
                }

                result.len = i;

                // Handle remaining elements
                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else {
            // Scalar fallback
            for i in 0..len {
                result.push(a[i] + b[i]);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                let mut i = 0;

                // Process 4 f32s at a time with NEON
                while i + 4 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    let a_vec = vld1q_f32(a_ptr);
                    let b_vec = vld1q_f32(b_ptr);
                    let result_vec = vaddq_f32(a_vec, b_vec);
                    vst1q_f32(result_ptr, result_vec);

                    i += 4;
                }

                result.len = i;

                // Handle remaining elements
                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else {
            // Scalar fallback
            for i in 0..len {
                result.push(a[i] + b[i]);
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // Scalar fallback for other architectures
        for i in 0..len {
            result.push(a[i] + b[i]);
        }
    }

    Ok(result)
}

/// High-performance SIMD multiplication of two f32 slices, returning a 32-byte-aligned result vector
pub fn simd_mul_aligned_f32(a: &[f32], b: &[f32]) -> Result<AlignedVec<f32>, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();
    let mut result: AlignedVec<f32> =
        AlignedVec::with_capacity(len).map_err(|_| "Failed to allocate aligned memory")?;

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                let mut i = 0;

                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_store_ps(result_ptr, result_vec);

                    i += 8;
                }

                result.len = i;

                for j in i..len {
                    result.push(a[j] * b[j]);
                }
            }
        } else {
            // Fallback
            for i in 0..len {
                result.push(a[i] * b[i]);
            }
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        // Fallback for other architectures
        for i in 0..len {
            result.push(a[i] * b[i]);
        }
    }

    Ok(result)
}

/// High-performance SIMD dot product of two f32 slices (aligned loads are used when the inputs allow)
pub fn simd_dot_aligned_f32(a: &[f32], b: &[f32]) -> Result<f32, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                let mut sums = _mm256_setzero_ps();
                let mut i = 0;

                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);

                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let product = _mm256_mul_ps(a_vec, b_vec);
                    sums = _mm256_add_ps(sums, product);

                    i += 8;
                }

                // Horizontal sum
                let high = _mm256_extractf128_ps(sums, 1);
                let low = _mm256_castps256_ps128(sums);
                let sum128 = _mm_add_ps(low, high);

                let shuf = _mm_shuffle_ps(sum128, sum128, 0b1110);
                let sum_partial = _mm_add_ps(sum128, shuf);
                let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
                let final_sum = _mm_add_ps(sum_partial, shuf2);

                let mut result = _mm_cvtss_f32(final_sum);

                // Handle remaining elements
                for j in i..len {
                    result += a[j] * b[j];
                }

                return Ok(result);
            }
        }
    }

    // Fallback
    let mut sum = 0.0f32;
    for i in 0..len {
        sum += a[i] * b[i];
    }
    Ok(sum)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aligned_vec_creation() {
        let mut vec = AlignedVec::<f32>::with_capacity(16).expect("Operation failed");
        assert_eq!(vec.len(), 0);
        assert_eq!(vec.capacity(), 16);

        vec.push(1.0);
        vec.push(2.0);
        assert_eq!(vec.len(), 2);
        assert_eq!(vec.as_slice(), &[1.0, 2.0]);
    }
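
    // A small additional sketch exercising from_vec, to_vec, fill, and clear
    // as defined above.
    #[test]
    fn test_aligned_vec_from_vec_fill_clear() {
        let mut vec = AlignedVec::from_vec(vec![1.0f32, 2.0, 3.0]).expect("Operation failed");
        assert_eq!(vec.to_vec(), vec![1.0f32, 2.0, 3.0]);

        // fill overwrites every initialized element with the given value
        vec.fill(7.0);
        assert_eq!(vec.as_slice(), &[7.0f32, 7.0, 7.0]);

        // clear drops all elements and resets the length, keeping the capacity
        vec.clear();
        assert!(vec.is_empty());
        assert_eq!(vec.capacity(), 3);
    }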

    #[test]
    fn test_simd_add_aligned() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let b = vec![10.0f32, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let result = simd_add_aligned_f32(&a, &b).expect("Operation failed");
        let expected = vec![11.0f32; 10];

        assert_eq!(result.as_slice(), &expected);
    }
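
    // A small additional sketch covering simd_mul_aligned_f32, using a length
    // that is not a multiple of the 8-lane AVX2 width so the scalar tail path
    // is exercised as well on AVX2 machines.
    #[test]
    fn test_simd_mul_aligned() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = vec![2.0f32; 9];

        let result = simd_mul_aligned_f32(&a, &b).expect("Operation failed");
        let expected: Vec<f32> = a.iter().map(|x| x * 2.0).collect();

        assert_eq!(result.as_slice(), expected.as_slice());
    }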

    #[test]
    fn test_simd_dot_aligned() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];

        let result = simd_dot_aligned_f32(&a, &b).expect("Operation failed");
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0; // = 70.0

        assert!((result - expected).abs() < 1e-6);
    }
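
    // A further sketch for the dot product with more than one AVX2 lane's
    // worth of data plus a remainder, checked against a scalar reference.
    #[test]
    fn test_simd_dot_aligned_with_remainder() {
        let a: Vec<f32> = (1..=11).map(|i| i as f32).collect();
        let b: Vec<f32> = (1..=11).map(|i| (12 - i) as f32).collect();

        let result = simd_dot_aligned_f32(&a, &b).expect("Operation failed");
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-4);
    }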

    #[test]
    fn test_alignment() {
        let mut vec = AlignedVec::<f32>::with_capacity(32).expect("Operation failed");
        // Add some elements to ensure non-empty vector
        vec.push(1.0);
        vec.push(2.0);
        vec.push(3.0);
        vec.push(4.0);

        let ptr = vec.as_slice().as_ptr() as usize;
        assert_eq!(ptr % SIMD_ALIGNMENT, 0, "Vector should be properly aligned");
    }
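
    // One more sketch covering set and get, plus the empty state.
    #[test]
    fn test_set_and_get() {
        let mut vec = AlignedVec::<f32>::with_capacity(4).expect("Operation failed");
        assert!(vec.is_empty());

        vec.push(1.0);
        vec.push(2.0);

        // set overwrites an existing element in place; get returns a reference
        vec.set(1, 5.0);
        assert_eq!(*vec.get(0), 1.0);
        assert_eq!(*vec.get(1), 5.0);
        assert_eq!(vec.as_slice(), &[1.0, 5.0]);
    }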
}