// yscv_tensor/aligned.rs
1//! A Vec-like container guaranteeing 32-byte aligned allocation for SIMD operations.
2
3use std::fmt;
4use std::iter::FromIterator;
5use std::ops::{Deref, DerefMut};
6
/// Alignment in bytes required for AVX operations.
// WHY 32: AVX/AVX2 aligned load/store (_mm256_load_ps) requires 32-byte alignment; also satisfies NEON (16B).
// Re-imported into `alloc_cache` as `super::ALIGN` so cached layouts always match.
const ALIGN: usize = 32;
10
/// A Vec-like container that guarantees 32-byte alignment for the data pointer.
///
/// This is required for AVX/AVX2 SIMD instructions which expect 32-byte aligned
/// memory. Standard `Vec<f32>` only guarantees 4-byte alignment.
pub struct AlignedVec<T> {
    // Data pointer. Invariant: 32-byte aligned whenever cap > 0; a
    // `NonNull::dangling()` placeholder (never dereferenced) when cap == 0.
    ptr: *mut T,
    // Number of initialized elements; `as_slice` exposes exactly this many.
    len: usize,
    // Number of elements the allocation can hold. Invariant: len <= cap.
    cap: usize,
}
20
// SAFETY: AlignedVec owns its allocation exclusively — `ptr` is only ever
// exposed through `&self`/`&mut self` borrows — so it can cross threads under
// the same bounds as Vec<T>: Send when T: Send, Sync when T: Sync.
#[allow(unsafe_code)]
unsafe impl<T: Send> Send for AlignedVec<T> {}
#[allow(unsafe_code)]
unsafe impl<T: Sync> Sync for AlignedVec<T> {}
27
28impl<T> Default for AlignedVec<T> {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl<T> AlignedVec<T> {
35    /// Creates a new empty `AlignedVec` with no allocation.
36    #[inline]
37    pub fn new() -> Self {
38        Self {
39            ptr: std::ptr::NonNull::dangling().as_ptr(),
40            len: 0,
41            cap: 0,
42        }
43    }
44
45    /// Creates an `AlignedVec` with the given capacity (aligned to 32 bytes).
46    #[allow(unsafe_code)]
47    pub fn with_capacity(cap: usize) -> Self {
48        if cap == 0 {
49            return Self::new();
50        }
51        let ptr = alloc_aligned::<T>(cap);
52        Self { ptr, len: 0, cap }
53    }
54
55    /// Creates an `AlignedVec` from an existing `Vec<T>`, copying data into
56    /// aligned storage.
57    #[allow(unsafe_code)]
58    pub fn from_vec(v: Vec<T>) -> Self {
59        let len = v.len();
60        if len == 0 {
61            return Self::new();
62        }
63        let ptr = alloc_aligned::<T>(len);
64        // SAFETY: ptr is valid for len elements, v.as_ptr() is valid for len elements,
65        // they do not overlap (fresh allocation).
66        unsafe {
67            std::ptr::copy_nonoverlapping(v.as_ptr(), ptr, len);
68        }
69        // Prevent Vec from dropping the elements (we moved them via copy).
70        // For Copy types this is fine. For non-Copy types that impl Drop,
71        // we must prevent double-free.
72        std::mem::forget(v);
73        Self { ptr, len, cap: len }
74    }
75
76    /// Returns the number of elements.
77    #[inline]
78    pub fn len(&self) -> usize {
79        self.len
80    }
81
82    /// Returns `true` if the vec is empty.
83    #[inline]
84    pub fn is_empty(&self) -> bool {
85        self.len == 0
86    }
87
88    /// Returns a raw pointer to the aligned data.
89    #[inline]
90    pub fn as_ptr(&self) -> *const T {
91        self.ptr
92    }
93
94    /// Returns a mutable raw pointer to the aligned data.
95    #[inline]
96    pub fn as_mut_ptr(&mut self) -> *mut T {
97        self.ptr
98    }
99
100    /// Returns an immutable slice over the contained elements.
101    #[inline]
102    #[allow(unsafe_code)]
103    pub fn as_slice(&self) -> &[T] {
104        if self.len == 0 {
105            return &[];
106        }
107        // SAFETY: ptr is valid for self.len elements, properly aligned.
108        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
109    }
110
111    /// Returns a mutable slice over the contained elements.
112    #[inline]
113    #[allow(unsafe_code)]
114    pub fn as_mut_slice(&mut self) -> &mut [T] {
115        if self.len == 0 {
116            return &mut [];
117        }
118        // SAFETY: ptr is valid for self.len elements, properly aligned, uniquely borrowed.
119        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
120    }
121
122    /// Appends a value to the end, growing if necessary.
123    #[allow(unsafe_code)]
124    pub fn push(&mut self, val: T) {
125        if self.len == self.cap {
126            self.grow();
127        }
128        // SAFETY: after grow, self.cap > self.len, so ptr.add(self.len) is valid.
129        unsafe {
130            self.ptr.add(self.len).write(val);
131        }
132        self.len += 1;
133    }
134
135    /// Grows the backing allocation (doubles capacity, minimum 4).
136    #[allow(unsafe_code)]
137    fn grow(&mut self) {
138        let new_cap = if self.cap == 0 {
139            4
140        } else {
141            self.cap.checked_mul(2).expect("capacity overflow")
142        };
143        let new_ptr = alloc_aligned::<T>(new_cap);
144        if self.cap > 0 {
145            if self.len > 0 {
146                // SAFETY: old ptr valid for self.len elements, new ptr valid for new_cap >= self.len.
147                unsafe {
148                    std::ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
149                }
150            }
151            dealloc_aligned::<T>(self.ptr, self.cap);
152        }
153        self.ptr = new_ptr;
154        self.cap = new_cap;
155    }
156}
157
impl<T: Copy> AlignedVec<T> {
    /// Creates an `AlignedVec` with `len` elements of **uninitialized** memory.
    ///
    /// The caller **must** write every element before reading any of them.
    /// This avoids the cost of zeroing the buffer when a subsequent SIMD pass
    /// will overwrite every element anyway.
    ///
    /// # Safety
    /// The allocation is valid and aligned, but the contents are indeterminate.
    /// Reading before writing is undefined behaviour.
    // NOTE(review): `len` is set up-front, so the uninitialized contents are
    // reachable through the *safe* `as_slice`/Deref APIs — the `T: Copy` bound
    // prevents Drop issues but not that UB. Consider making this an
    // `unsafe fn` in a future breaking release.
    #[allow(unsafe_code)]
    #[inline]
    pub fn uninitialized(len: usize) -> Self {
        if len == 0 {
            return Self::new();
        }
        let ptr = alloc_aligned::<T>(len);
        Self { ptr, len, cap: len }
    }
}
178
179impl<T: Default + Copy> AlignedVec<T> {
180    /// Creates an `AlignedVec` of `len` elements, each set to `val`.
181    #[allow(unsafe_code)]
182    pub fn filled(len: usize, val: T) -> Self {
183        if len == 0 {
184            return Self::new();
185        }
186        let ptr = alloc_aligned::<T>(len);
187        // SAFETY: ptr is valid for len elements.
188        unsafe {
189            for i in 0..len {
190                ptr.add(i).write(val);
191            }
192        }
193        Self { ptr, len, cap: len }
194    }
195
196    /// Creates an `AlignedVec` of `len` zero/default elements.
197    pub fn zeros(len: usize) -> Self {
198        Self::filled(len, T::default())
199    }
200
201    /// Creates an `AlignedVec` of `len` zero-initialized elements using `alloc_zeroed`.
202    ///
203    /// This is faster than `filled(len, T::default())` for large allocations because
204    /// the OS can provide pre-zeroed pages without writing every byte.
205    ///
206    /// # Safety requirement
207    /// Only correct when all-zero bytes is a valid representation for `T` (true for
208    /// all primitive numeric types like f32, u8, i32, etc.).
209    #[allow(unsafe_code)]
210    pub fn calloc(len: usize) -> Self {
211        if len == 0 {
212            return Self::new();
213        }
214        let size = len
215            .checked_mul(std::mem::size_of::<T>())
216            .expect("allocation size overflow");
217        let size = size.max(1);
218        let layout =
219            std::alloc::Layout::from_size_align(size, ALIGN).expect("invalid allocation layout");
220        // SAFETY: layout has non-zero size. alloc_zeroed returns zeroed memory.
221        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
222        if ptr.is_null() {
223            std::alloc::handle_alloc_error(layout);
224        }
225        Self {
226            ptr: ptr as *mut T,
227            len,
228            cap: len,
229        }
230    }
231}
232
impl<T> Drop for AlignedVec<T> {
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        // cap == 0 means `ptr` is the dangling placeholder from `new()` —
        // there is nothing to drop or free.
        if self.cap == 0 {
            return;
        }
        // Drop the `len` initialized elements in place *before* releasing the
        // allocation. `slice_from_raw_parts_mut` builds the raw slice pointer
        // without creating an intermediate &mut [T].
        // SAFETY: ptr is valid for self.len initialized elements (struct invariant).
        unsafe {
            std::ptr::drop_in_place(std::ptr::slice_from_raw_parts_mut(self.ptr, self.len));
        }
        // Returns the allocation to the thread-local cache or the global allocator.
        dealloc_aligned::<T>(self.ptr, self.cap);
    }
}
246
247impl<T: Clone> Clone for AlignedVec<T> {
248    #[allow(unsafe_code)]
249    fn clone(&self) -> Self {
250        if self.len == 0 {
251            return Self::new();
252        }
253        let ptr = alloc_aligned::<T>(self.len);
254        // SAFETY: ptr valid for self.len elements, we write each one via clone.
255        unsafe {
256            for i in 0..self.len {
257                ptr.add(i).write((*self.ptr.add(i)).clone());
258            }
259        }
260        Self {
261            ptr,
262            len: self.len,
263            cap: self.len,
264        }
265    }
266}
267
268impl<T: fmt::Debug> fmt::Debug for AlignedVec<T> {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        f.debug_list().entries(self.as_slice().iter()).finish()
271    }
272}
273
274impl<T: PartialEq> PartialEq for AlignedVec<T> {
275    fn eq(&self, other: &Self) -> bool {
276        self.as_slice() == other.as_slice()
277    }
278}
279
impl<T> Deref for AlignedVec<T> {
    type Target = [T];

    /// Slice view of the elements — lets slice methods (`iter`, indexing,
    /// `sort`, …) be called directly on the vec.
    #[inline]
    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}
288
impl<T> DerefMut for AlignedVec<T> {
    /// Mutable slice view — enables `v[i] = x`, `iter_mut`, in-place sorts, etc.
    #[inline]
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
295
296impl<T> FromIterator<T> for AlignedVec<T> {
297    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
298        let iter = iter.into_iter();
299        let (lower, _) = iter.size_hint();
300        let mut v = AlignedVec::with_capacity(lower);
301        for item in iter {
302            v.push(item);
303        }
304        v
305    }
306}
307
// ── Allocation helpers ─────────────────────────────────────────────

/// Thread-local cache of freed aligned allocations.
/// Avoids repeated mmap/munmap for same-size buffers (like PyTorch's CachingAllocator).
/// Each entry: (pointer, byte_size).
#[allow(unsafe_code)]
mod alloc_cache {

    use std::cell::RefCell;

    // WHY 8: 8 cached allocations per thread balances memory reuse with bounded memory growth.
    // NOTE(review): the cap is on entry *count*, not bytes — a few very large
    // cached buffers can stay resident per thread; consider a byte budget if
    // memory pressure becomes an issue.
    const MAX_CACHED: usize = 8;
    const ALIGN: usize = super::ALIGN;

    /// RAII wrapper that properly frees all cached aligned allocations when dropped.
    /// This prevents SIGSEGV/SIGABRT during process exit from leaked cached pointers.
    struct AllocCache {
        // (pointer, byte size it was allocated with). The Layout is
        // reconstructed from the size plus the constant ALIGN on reuse/free.
        entries: Vec<(*mut u8, usize)>,
    }

    impl AllocCache {
        const fn new() -> Self {
            Self {
                entries: Vec::new(),
            }
        }
    }

    impl Drop for AllocCache {
        // Runs as the TLS destructor at thread exit: return every cached block
        // to the global allocator under the same layout it was allocated with.
        fn drop(&mut self) {
            for &(ptr, size) in &self.entries {
                if !ptr.is_null()
                    && let Ok(layout) = std::alloc::Layout::from_size_align(size, ALIGN)
                {
                    // SAFETY: every cached ptr came from std::alloc with exactly
                    // this (size, ALIGN) layout (see alloc_aligned/dealloc_aligned).
                    unsafe {
                        std::alloc::dealloc(ptr, layout);
                    }
                }
            }
        }
    }

    thread_local! {
        static CACHE: RefCell<AllocCache> = const { RefCell::new(AllocCache::new()) };
    }

    /// Pops a cached allocation of *exactly* `size` bytes, if one exists.
    /// Returns `None` when the cache is empty, has no size match, is disabled
    /// (Miri), or TLS is already torn down.
    pub(super) fn try_alloc(size: usize) -> Option<*mut u8> {
        if cfg!(miri) {
            return None;
        } // Disable cache under Miri to avoid false leak reports
        // Use try_with to gracefully handle TLS already destroyed during thread/process exit
        CACHE
            .try_with(|c| {
                let mut cache = c.borrow_mut();
                if let Some(pos) = cache.entries.iter().position(|&(_, s)| s == size) {
                    let (ptr, _) = cache.entries.swap_remove(pos);
                    Some(ptr)
                } else {
                    None
                }
            })
            .ok()
            .flatten()
    }

    /// Offers `ptr` (a `size`-byte block) to the cache. Returns `true` if the
    /// cache took ownership (caller must NOT free it), `false` if the caller
    /// must free it normally.
    pub(super) fn try_dealloc(ptr: *mut u8, size: usize) -> bool {
        if cfg!(miri) {
            return false;
        } // Always free under Miri to avoid false leak reports
        // Use try_with: if TLS is destroyed (thread exiting), fall through to real dealloc
        CACHE
            .try_with(|c| {
                let mut cache = c.borrow_mut();
                if cache.entries.len() < MAX_CACHED {
                    cache.entries.push((ptr, size));
                    true
                } else {
                    false
                }
            })
            .unwrap_or(false)
    }
}
392
/// Allocates 32-byte-aligned memory for `count` elements of type `T`,
/// reusing a same-size buffer from the thread-local cache when possible.
///
/// # Panics
/// Panics if `count == 0` or the byte size overflows `usize`; aborts via
/// `handle_alloc_error` if the allocator returns null.
#[allow(unsafe_code)]
fn alloc_aligned<T>(count: usize) -> *mut T {
    assert!(count > 0, "cannot allocate zero-sized aligned buffer");
    let size = count
        .checked_mul(std::mem::size_of::<T>())
        .expect("allocation size overflow");
    // max(1): a zero-sized T still gets a real, deallocatable 1-byte block
    // (dealloc_aligned mirrors this computation so layouts always match).
    let size = size.max(1);

    // Try thread-local cache first. Cached blocks were allocated with the same
    // ALIGN, so alignment is preserved across reuse.
    if let Some(ptr) = alloc_cache::try_alloc(size) {
        return ptr as *mut T;
    }

    let layout =
        std::alloc::Layout::from_size_align(size, ALIGN).expect("invalid allocation layout");
    // SAFETY: layout has non-zero size.
    let ptr = unsafe { std::alloc::alloc(layout) };
    if ptr.is_null() {
        std::alloc::handle_alloc_error(layout);
    }
    ptr as *mut T
}
415
/// Deallocates aligned memory previously allocated with `alloc_aligned` (or
/// `calloc`): first offers the block to the thread-local cache, otherwise
/// frees it under the same `Layout` it was allocated with.
#[allow(unsafe_code)]
fn dealloc_aligned<T>(ptr: *mut T, cap: usize) {
    // Must mirror alloc_aligned's size computation exactly so the Layout matches.
    let size = match cap.checked_mul(std::mem::size_of::<T>()) {
        Some(s) => s.max(1),
        None => return, // overflow — cannot reconstruct layout, leak rather than panic in Drop
    };

    // Try to cache instead of freeing
    if alloc_cache::try_dealloc(ptr as *mut u8, size) {
        return;
    }

    if let Ok(layout) = std::alloc::Layout::from_size_align(size, ALIGN) {
        // SAFETY: ptr was allocated with this layout via alloc_aligned.
        unsafe {
            std::alloc::dealloc(ptr as *mut u8, layout);
        }
    }
}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn alignment_is_32_bytes() {
        let v = AlignedVec::<f32>::with_capacity(64);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn filled_and_len() {
        let v = AlignedVec::filled(100, 1.0f32);
        assert_eq!(v.len(), 100);
        assert!(v.iter().all(|&x| x == 1.0));
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn zeros_default() {
        let v = AlignedVec::<f32>::zeros(16);
        assert_eq!(v.len(), 16);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn from_vec_copies_and_aligns() {
        let orig = vec![1.0f32, 2.0, 3.0, 4.0];
        let aligned = AlignedVec::from_vec(orig);
        assert_eq!(aligned.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(aligned.as_ptr() as usize % 32, 0);
    }

    // New: empty input must not allocate and must round-trip cleanly.
    #[test]
    fn from_vec_empty() {
        let v = AlignedVec::<f32>::from_vec(Vec::new());
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
        assert_eq!(v.as_slice(), &[] as &[f32]);
    }

    // New: calloc must hand back fully-zeroed, aligned storage (including a
    // length that is not a multiple of the alignment).
    #[test]
    fn calloc_is_zeroed_and_aligned() {
        let v = AlignedVec::<u32>::calloc(33);
        assert_eq!(v.len(), 33);
        assert!(v.iter().all(|&x| x == 0));
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    // New: PartialEq compares contents, not pointers or capacities.
    #[test]
    fn partial_eq_compares_contents() {
        let a = AlignedVec::filled(4, 7.0f32);
        let b = AlignedVec::from_vec(vec![7.0f32; 4]);
        assert_eq!(a, b);
        let c = AlignedVec::filled(3, 7.0f32);
        assert!(a != c);
    }

    #[test]
    fn push_and_grow() {
        let mut v = AlignedVec::<f32>::new();
        for i in 0..100 {
            v.push(i as f32);
        }
        assert_eq!(v.len(), 100);
        assert_eq!(v[50], 50.0);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn clone_preserves_alignment() {
        let v = AlignedVec::filled(10, 42.0f32);
        let v2 = v.clone();
        assert_eq!(v.as_slice(), v2.as_slice());
        assert_eq!(v2.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn from_iterator() {
        let v: AlignedVec<f32> = (0..10).map(|i| i as f32).collect();
        assert_eq!(v.len(), 10);
        assert_eq!(v[5], 5.0);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn empty_vec_operations() {
        let v = AlignedVec::<f32>::new();
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
        assert_eq!(v.as_slice(), &[] as &[f32]);
    }

    #[test]
    fn deref_slice_access() {
        let v = AlignedVec::filled(5, 3.0f32);
        // Test that Deref to [T] works
        let sum: f32 = v.iter().sum();
        assert_eq!(sum, 15.0);
    }
}