// yscv_tensor/aligned.rs
1//! A Vec-like container guaranteeing 32-byte aligned allocation for SIMD operations.
2
3use std::fmt;
4use std::iter::FromIterator;
5use std::ops::{Deref, DerefMut};
6
/// Alignment in bytes required for AVX operations.
///
/// Every allocation in this file (including cached ones) uses this alignment.
// WHY 32: AVX/AVX2 aligned load/store (_mm256_load_ps) requires 32-byte alignment; also satisfies NEON (16B).
const ALIGN: usize = 32;
10
/// A Vec-like container that guarantees 32-byte alignment for the data pointer.
///
/// This is required for AVX/AVX2 SIMD instructions which expect 32-byte aligned
/// memory. Standard `Vec<f32>` only guarantees 4-byte alignment.
pub struct AlignedVec<T> {
    // Pointer to the aligned allocation; dangling (never dereferenced) when cap == 0.
    ptr: *mut T,
    // Number of initialized elements (<= cap).
    len: usize,
    // Number of elements the allocation can hold; 0 means no allocation exists.
    cap: usize,
}
20
// SAFETY: AlignedVec owns its data exclusively (the raw pointer is never shared
// or aliased by another AlignedVec), so it is safe to Send/Sync when T is
// Send/Sync — the same guarantees Vec<T> provides.
#[allow(unsafe_code)]
unsafe impl<T: Send> Send for AlignedVec<T> {}
#[allow(unsafe_code)]
unsafe impl<T: Sync> Sync for AlignedVec<T> {}
27
28impl<T> Default for AlignedVec<T> {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
impl<T> AlignedVec<T> {
    /// Creates a new empty `AlignedVec` with no allocation.
    ///
    /// The stored pointer is dangling (but aligned for `T`); it is never
    /// dereferenced while `len` and `cap` are both zero.
    #[inline]
    pub fn new() -> Self {
        Self {
            ptr: std::ptr::NonNull::dangling().as_ptr(),
            len: 0,
            cap: 0,
        }
    }

    /// Creates an `AlignedVec` with the given capacity (aligned to 32 bytes).
    ///
    /// The vec starts empty (`len == 0`); the capacity is uninitialized.
    #[allow(unsafe_code)]
    pub fn with_capacity(cap: usize) -> Self {
        // Zero capacity keeps the no-allocation dangling-pointer state.
        if cap == 0 {
            return Self::new();
        }
        let ptr = alloc_aligned::<T>(cap);
        Self { ptr, len: 0, cap }
    }

    /// Creates an `AlignedVec` from an existing `Vec<T>`, copying data into
    /// aligned storage. Requires `T: Copy` to ensure bitwise copy is sound.
    #[allow(unsafe_code)]
    pub fn from_vec(v: Vec<T>) -> Self
    where
        T: Copy,
    {
        let len = v.len();
        if len == 0 {
            return Self::new();
        }
        let ptr = alloc_aligned::<T>(len);
        // SAFETY: ptr is valid for len elements, v.as_ptr() is valid for len elements,
        // they do not overlap (fresh allocation).
        unsafe {
            std::ptr::copy_nonoverlapping(v.as_ptr(), ptr, len);
        }
        // v is dropped here normally — T: Copy so no element destructors,
        // and Vec's drop frees the backing allocation.
        drop(v);
        Self { ptr, len, cap: len }
    }

    /// Returns the number of elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the vec is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a raw pointer to the aligned data.
    ///
    /// The pointer is dangling (must not be dereferenced) when `cap == 0`.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Returns a mutable raw pointer to the aligned data.
    ///
    /// The pointer is dangling (must not be dereferenced) when `cap == 0`.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Returns an immutable slice over the contained elements.
    #[inline]
    #[allow(unsafe_code)]
    pub fn as_slice(&self) -> &[T] {
        // Early-out so we never build a slice from the empty vec's dangling pointer.
        if self.len == 0 {
            return &[];
        }
        // SAFETY: ptr is valid for self.len elements, properly aligned.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Returns a mutable slice over the contained elements.
    #[inline]
    #[allow(unsafe_code)]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // Early-out so we never build a slice from the empty vec's dangling pointer.
        if self.len == 0 {
            return &mut [];
        }
        // SAFETY: ptr is valid for self.len elements, properly aligned, uniquely borrowed.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Appends a value to the end, growing if necessary.
    #[allow(unsafe_code)]
    pub fn push(&mut self, val: T) {
        if self.len == self.cap {
            self.grow();
        }
        // SAFETY: after grow, self.cap > self.len, so ptr.add(self.len) is valid.
        unsafe {
            self.ptr.add(self.len).write(val);
        }
        self.len += 1;
    }

    /// Grows the backing allocation (doubles capacity, minimum 4).
    ///
    /// Allocates a fresh aligned block and copies elements over instead of
    /// `realloc`, which would not preserve the 32-byte alignment guarantee.
    #[allow(unsafe_code)]
    fn grow(&mut self) {
        let new_cap = if self.cap == 0 {
            4
        } else {
            self.cap.checked_mul(2).expect("capacity overflow")
        };
        let new_ptr = alloc_aligned::<T>(new_cap);
        if self.cap > 0 {
            if self.len > 0 {
                // SAFETY: old ptr valid for self.len elements, new ptr valid for new_cap >= self.len.
                unsafe {
                    std::ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
                }
            }
            // Release (or cache) the old block; elements were moved bitwise above,
            // so no destructors must run here.
            dealloc_aligned::<T>(self.ptr, self.cap);
        }
        self.ptr = new_ptr;
        self.cap = new_cap;
    }
}
159
impl<T: Copy> AlignedVec<T> {
    /// Creates an `AlignedVec` with `len` elements of **uninitialized** memory.
    ///
    /// The caller **must** write every element before reading any of them.
    /// This avoids the cost of zeroing the buffer when a subsequent SIMD pass
    /// will overwrite every element anyway.
    ///
    /// # Safety
    /// The allocation is valid and aligned, but the contents are indeterminate.
    /// Reading before writing is undefined behaviour.
    ///
    /// NOTE(review): this is a safe `fn`, yet `len` is set immediately, so safe
    /// code can observe uninitialized bytes through `Deref`/`as_slice` — this
    /// arguably should be an `unsafe fn`. Changing the signature would break
    /// existing callers; flagged for a future audit.
    #[allow(unsafe_code)]
    #[inline]
    pub fn uninitialized(len: usize) -> Self {
        if len == 0 {
            return Self::new();
        }
        let ptr = alloc_aligned::<T>(len);
        Self { ptr, len, cap: len }
    }
}
180
181impl<T: Default + Copy> AlignedVec<T> {
182    /// Creates an `AlignedVec` of `len` elements, each set to `val`.
183    #[allow(unsafe_code)]
184    pub fn filled(len: usize, val: T) -> Self {
185        if len == 0 {
186            return Self::new();
187        }
188        let ptr = alloc_aligned::<T>(len);
189        // SAFETY: ptr is valid for len elements.
190        unsafe {
191            for i in 0..len {
192                ptr.add(i).write(val);
193            }
194        }
195        Self { ptr, len, cap: len }
196    }
197
198    /// Creates an `AlignedVec` of `len` zero/default elements.
199    pub fn zeros(len: usize) -> Self {
200        Self::filled(len, T::default())
201    }
202
203    /// Creates an `AlignedVec` of `len` zero-initialized elements using `alloc_zeroed`.
204    ///
205    /// This is faster than `filled(len, T::default())` for large allocations because
206    /// the OS can provide pre-zeroed pages without writing every byte.
207    ///
208    /// # Safety requirement
209    /// Only correct when all-zero bytes is a valid representation for `T` (true for
210    /// all primitive numeric types like f32, u8, i32, etc.).
211    #[allow(unsafe_code)]
212    pub fn calloc(len: usize) -> Self {
213        if len == 0 {
214            return Self::new();
215        }
216        let size = len
217            .checked_mul(std::mem::size_of::<T>())
218            .expect("allocation size overflow");
219        let size = size.max(1);
220        let layout =
221            std::alloc::Layout::from_size_align(size, ALIGN).expect("invalid allocation layout");
222        // SAFETY: layout has non-zero size. alloc_zeroed returns zeroed memory.
223        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
224        if ptr.is_null() {
225            std::alloc::handle_alloc_error(layout);
226        }
227        Self {
228            ptr: ptr as *mut T,
229            len,
230            cap: len,
231        }
232    }
233}
234
impl<T> Drop for AlignedVec<T> {
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        // cap == 0 means ptr is dangling (from `new`) — nothing to drop or free.
        if self.cap == 0 {
            return;
        }
        // Drop elements in place.
        // SAFETY: the first self.len elements are initialized; drop_in_place runs
        // their destructors without freeing the backing storage.
        unsafe {
            std::ptr::drop_in_place(std::ptr::slice_from_raw_parts_mut(self.ptr, self.len));
        }
        // Release (or hand to the thread-local cache) the allocation of cap elements.
        dealloc_aligned::<T>(self.ptr, self.cap);
    }
}
248
249impl<T: Clone> Clone for AlignedVec<T> {
250    #[allow(unsafe_code)]
251    fn clone(&self) -> Self {
252        if self.len == 0 {
253            return Self::new();
254        }
255        let ptr = alloc_aligned::<T>(self.len);
256        // SAFETY: ptr valid for self.len elements, we write each one via clone.
257        unsafe {
258            for i in 0..self.len {
259                ptr.add(i).write((*self.ptr.add(i)).clone());
260            }
261        }
262        Self {
263            ptr,
264            len: self.len,
265            cap: self.len,
266        }
267    }
268}
269
270impl<T: fmt::Debug> fmt::Debug for AlignedVec<T> {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        f.debug_list().entries(self.as_slice().iter()).finish()
273    }
274}
275
276impl<T: PartialEq> PartialEq for AlignedVec<T> {
277    fn eq(&self, other: &Self) -> bool {
278        self.as_slice() == other.as_slice()
279    }
280}
281
impl<T> Deref for AlignedVec<T> {
    type Target = [T];

    /// Enables slice methods, iteration, and indexing directly on the vec.
    #[inline]
    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}
290
impl<T> DerefMut for AlignedVec<T> {
    /// Enables mutable slice methods and index assignment directly on the vec.
    #[inline]
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
297
298impl<T> FromIterator<T> for AlignedVec<T> {
299    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
300        let iter = iter.into_iter();
301        let (lower, _) = iter.size_hint();
302        let mut v = AlignedVec::with_capacity(lower);
303        for item in iter {
304            v.push(item);
305        }
306        v
307    }
308}
309
310// ── Allocation helpers ─────────────────────────────────────────────
311
/// Thread-local cache of freed aligned allocations.
/// Avoids repeated mmap/munmap for same-size buffers (like PyTorch's CachingAllocator).
/// Each entry: (pointer, byte_size).
#[allow(unsafe_code)]
mod alloc_cache {

    use std::cell::RefCell;

    // WHY 8: 8 cached allocations per thread balances memory reuse with bounded memory growth.
    const MAX_CACHED: usize = 8;
    const ALIGN: usize = super::ALIGN;

    /// RAII wrapper that properly frees all cached aligned allocations when dropped.
    /// This prevents SIGSEGV/SIGABRT during process exit from leaked cached pointers.
    struct AllocCache {
        // (pointer, byte size) pairs; every pointer was obtained from the global
        // allocator with Layout::from_size_align(size, ALIGN).
        entries: Vec<(*mut u8, usize)>,
    }

    impl AllocCache {
        // const so it can initialize the thread_local in a `const` block (no lazy init).
        const fn new() -> Self {
            Self {
                entries: Vec::new(),
            }
        }
    }

    impl Drop for AllocCache {
        fn drop(&mut self) {
            // Free each cached block with the same layout it was allocated with.
            for &(ptr, size) in &self.entries {
                if !ptr.is_null()
                    && let Ok(layout) = std::alloc::Layout::from_size_align(size, ALIGN)
                {
                    unsafe {
                        std::alloc::dealloc(ptr, layout);
                    }
                }
            }
        }
    }

    thread_local! {
        static CACHE: RefCell<AllocCache> = const { RefCell::new(AllocCache::new()) };
    }

    /// Pops a cached allocation of exactly `size` bytes, if one exists.
    /// Returns `None` under Miri or when thread-local storage is unavailable.
    pub(super) fn try_alloc(size: usize) -> Option<*mut u8> {
        if cfg!(miri) {
            return None;
        } // Disable cache under Miri to avoid false leak reports
        // Use try_with to gracefully handle TLS already destroyed during thread/process exit
        CACHE
            .try_with(|c| {
                let mut cache = c.borrow_mut();
                if let Some(pos) = cache.entries.iter().position(|&(_, s)| s == size) {
                    let (ptr, _) = cache.entries.swap_remove(pos);
                    Some(ptr)
                } else {
                    None
                }
            })
            .ok()
            .flatten()
    }

    /// Offers `ptr` (a block of `size` bytes) to the cache. Returns `true` when
    /// the cache takes ownership; `false` means the caller must free it.
    pub(super) fn try_dealloc(ptr: *mut u8, size: usize) -> bool {
        if cfg!(miri) {
            return false;
        } // Always free under Miri to avoid false leak reports
        // Use try_with: if TLS is destroyed (thread exiting), fall through to real dealloc
        CACHE
            .try_with(|c| {
                let mut cache = c.borrow_mut();
                if cache.entries.len() < MAX_CACHED {
                    cache.entries.push((ptr, size));
                    true
                } else {
                    false
                }
            })
            .unwrap_or(false)
    }
}
394
395#[allow(unsafe_code)]
396fn alloc_aligned<T>(count: usize) -> *mut T {
397    assert!(count > 0, "cannot allocate zero-sized aligned buffer");
398    let size = count
399        .checked_mul(std::mem::size_of::<T>())
400        .expect("allocation size overflow");
401    let size = size.max(1);
402
403    // Try thread-local cache first
404    if let Some(ptr) = alloc_cache::try_alloc(size) {
405        return ptr as *mut T;
406    }
407
408    let layout =
409        std::alloc::Layout::from_size_align(size, ALIGN).expect("invalid allocation layout");
410    // SAFETY: layout has non-zero size.
411    let ptr = unsafe { std::alloc::alloc(layout) };
412    if ptr.is_null() {
413        std::alloc::handle_alloc_error(layout);
414    }
415    ptr as *mut T
416}
417
418/// Deallocates aligned memory previously allocated with `alloc_aligned`.
419#[allow(unsafe_code)]
420fn dealloc_aligned<T>(ptr: *mut T, cap: usize) {
421    let size = match cap.checked_mul(std::mem::size_of::<T>()) {
422        Some(s) => s.max(1),
423        None => return, // overflow — cannot reconstruct layout, leak rather than panic in Drop
424    };
425
426    // Try to cache instead of freeing
427    if alloc_cache::try_dealloc(ptr as *mut u8, size) {
428        return;
429    }
430
431    if let Ok(layout) = std::alloc::Layout::from_size_align(size, ALIGN) {
432        // SAFETY: ptr was allocated with this layout via alloc_aligned.
433        unsafe {
434            std::alloc::dealloc(ptr as *mut u8, layout);
435        }
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests also assert `ptr % 32 == 0` — the core guarantee of this type.

    #[test]
    fn alignment_is_32_bytes() {
        let v = AlignedVec::<f32>::with_capacity(64);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn filled_and_len() {
        let v = AlignedVec::filled(100, 1.0f32);
        assert_eq!(v.len(), 100);
        assert!(v.iter().all(|&x| x == 1.0));
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn zeros_default() {
        let v = AlignedVec::<f32>::zeros(16);
        assert_eq!(v.len(), 16);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn from_vec_copies_and_aligns() {
        let orig = vec![1.0f32, 2.0, 3.0, 4.0];
        let aligned = AlignedVec::from_vec(orig);
        assert_eq!(aligned.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(aligned.as_ptr() as usize % 32, 0);
    }

    // 100 pushes from empty forces several grow() cycles (4 → 8 → … → 128).
    #[test]
    fn push_and_grow() {
        let mut v = AlignedVec::<f32>::new();
        for i in 0..100 {
            v.push(i as f32);
        }
        assert_eq!(v.len(), 100);
        assert_eq!(v[50], 50.0);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn clone_preserves_alignment() {
        let v = AlignedVec::filled(10, 42.0f32);
        let v2 = v.clone();
        assert_eq!(v.as_slice(), v2.as_slice());
        assert_eq!(v2.as_ptr() as usize % 32, 0);
    }

    #[test]
    fn from_iterator() {
        let v: AlignedVec<f32> = (0..10).map(|i| i as f32).collect();
        assert_eq!(v.len(), 10);
        assert_eq!(v[5], 5.0);
        assert_eq!(v.as_ptr() as usize % 32, 0);
    }

    // The empty vec holds a dangling pointer; these calls must not touch it.
    #[test]
    fn empty_vec_operations() {
        let v = AlignedVec::<f32>::new();
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
        assert_eq!(v.as_slice(), &[] as &[f32]);
    }

    #[test]
    fn deref_slice_access() {
        let v = AlignedVec::filled(5, 3.0f32);
        // Test that Deref to [T] works
        let sum: f32 = v.iter().sum();
        assert_eq!(sum, 15.0);
    }
}