Skip to main content

solana_sbpf/
aligned_memory.rs

//! Aligned memory: byte buffers whose backing allocation is aligned to a const-generic `ALIGN`.

3use std::{
4    alloc::{alloc, alloc_zeroed, dealloc, handle_alloc_error, Layout},
5    mem,
6    ptr::NonNull,
7};
8
/// Marker for scalar "plain old data" types: `Copy` values whose bytes can be stored
/// verbatim into an [`AlignedMemory`] buffer.
// NOTE(review): `AlignedMemory::write_unchecked` copies a value's bytes as-is, so an
// implementation for a type with padding would expose uninitialized bytes through
// `as_slice`. Consider making this an `unsafe trait` in a future breaking release.
pub trait Pod: Copy {}

/// Implements [`Pod`] for every listed primitive type.
macro_rules! impl_pod_for {
    ($($scalar:ty),* $(,)?) => {
        $(impl Pod for $scalar {})*
    };
}

impl_pod_for!(u8, u16, u32, u64, i8, i16, i32, i64);

21/// Provides u8 slices at a specified alignment
22#[derive(Debug, PartialEq, Eq)]
23pub struct AlignedMemory<const ALIGN: usize> {
24    mem: AlignedVec<ALIGN>,
25    zero_up_to_max_len: bool,
26}
27
28impl<const ALIGN: usize> AlignedMemory<ALIGN> {
29    /// Returns a filled AlignedMemory by copying the given slice
30    pub fn from_slice(data: &[u8]) -> Self {
31        let max_len = data.len();
32        let mut mem = AlignedVec::new(max_len, false);
33        unsafe {
34            // SAFETY: `mem` was allocated with `max_len` bytes
35            core::ptr::copy_nonoverlapping(data.as_ptr(), mem.as_mut_ptr(), max_len);
36            mem.set_len(max_len);
37        }
38        Self {
39            mem,
40            zero_up_to_max_len: false,
41        }
42    }
43
44    /// Returns a new empty AlignedMemory with uninitialized preallocated memory
45    pub fn with_capacity(max_len: usize) -> Self {
46        let mem = AlignedVec::new(max_len, false);
47        Self {
48            mem,
49            zero_up_to_max_len: false,
50        }
51    }
52
53    /// Returns a new empty AlignedMemory with zero initialized preallocated memory
54    pub fn with_capacity_zeroed(max_len: usize) -> Self {
55        let mem = AlignedVec::new(max_len, true);
56        Self {
57            mem,
58            zero_up_to_max_len: true,
59        }
60    }
61
62    /// Returns a new filled AlignedMemory with zero initialized preallocated memory
63    pub fn zero_filled(max_len: usize) -> Self {
64        let mut mem = AlignedVec::new(max_len, true);
65        // SAFETY: Bytes were zeroed
66        unsafe {
67            mem.set_len(max_len);
68        }
69        Self {
70            mem,
71            zero_up_to_max_len: true,
72        }
73    }
74
75    /// Calculate memory size (allocated memory block and the size of [`AlignedMemory`] itself).
76    pub fn mem_size(&self) -> usize {
77        self.mem.capacity().saturating_add(mem::size_of::<Self>())
78    }
79
80    /// Get the length of the data
81    pub fn len(&self) -> usize {
82        self.mem.len()
83    }
84
85    /// Is the memory empty
86    pub fn is_empty(&self) -> bool {
87        self.mem.is_empty()
88    }
89
90    /// Get the current write index
91    pub fn write_index(&self) -> usize {
92        self.mem.len()
93    }
94
95    /// Get an aligned slice
96    pub fn as_slice(&self) -> &[u8] {
97        self.mem.as_slice()
98    }
99
100    /// Get an aligned mutable slice
101    pub fn as_slice_mut(&mut self) -> &mut [u8] {
102        self.mem.as_slice_mut()
103    }
104
105    /// Grows memory with `value` repeated `num` times starting at the `write_index`
106    pub fn fill_write(&mut self, num: usize, value: u8) -> std::io::Result<()> {
107        let (ptr, new_len) = self.mem.write_ptr_for(num).ok_or_else(|| {
108            std::io::Error::new(
109                std::io::ErrorKind::InvalidInput,
110                "aligned memory fill_write failed",
111            )
112        })?;
113
114        if self.zero_up_to_max_len && value == 0 {
115            // No action needed because up to `max_len` is zeroed and no shrinking is allowed
116        } else {
117            unsafe {
118                core::ptr::write_bytes(ptr, value, num);
119            }
120        }
121        unsafe {
122            self.mem.set_len(new_len);
123        }
124        Ok(())
125    }
126
127    /// Write a generic type T into the memory.
128    ///
129    /// # Safety
130    ///
131    /// Unsafe since it assumes that there is enough capacity.
132    pub unsafe fn write_unchecked<T: Pod>(&mut self, value: T) {
133        let pos = self.mem.len();
134        let new_len = pos.saturating_add(mem::size_of::<T>());
135        debug_assert!(new_len <= self.mem.capacity());
136        unsafe {
137            self.mem.write_ptr().cast::<T>().write_unaligned(value);
138            self.mem.set_len(new_len);
139        }
140    }
141
142    /// Write a slice of bytes into the memory.
143    ///
144    /// # Safety
145    ///
146    /// Unsafe since it assumes that there is enough capacity.
147    pub unsafe fn write_all_unchecked(&mut self, value: &[u8]) {
148        let pos = self.mem.len();
149        let new_len = pos.saturating_add(value.len());
150        debug_assert!(new_len <= self.mem.capacity());
151        core::ptr::copy_nonoverlapping(value.as_ptr(), self.mem.write_ptr(), value.len());
152        self.mem.set_len(new_len);
153    }
154}
155
156// Custom Clone impl is needed to ensure alignment. Derived clone would just
157// clone self.mem and there would be no guarantee that the clone allocation is
158// aligned.
159impl<const ALIGN: usize> Clone for AlignedMemory<ALIGN> {
160    fn clone(&self) -> Self {
161        AlignedMemory::from_slice(self.as_slice())
162    }
163}
164
165impl<const ALIGN: usize> std::io::Write for AlignedMemory<ALIGN> {
166    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167        let (ptr, new_len) = self.mem.write_ptr_for(buf.len()).ok_or_else(|| {
168            std::io::Error::new(
169                std::io::ErrorKind::InvalidInput,
170                "aligned memory fill_write failed",
171            )
172        })?;
173        unsafe {
174            core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len());
175            self.mem.set_len(new_len);
176        }
177        Ok(buf.len())
178    }
179    fn flush(&mut self) -> std::io::Result<()> {
180        Ok(())
181    }
182}
183
184impl<const ALIGN: usize, T: AsRef<[u8]>> From<T> for AlignedMemory<ALIGN> {
185    fn from(bytes: T) -> Self {
186        AlignedMemory::from_slice(bytes.as_ref())
187    }
188}
189
/// Returns true if `ptr` is a multiple of `align`.
///
/// Any non-zero `align` is accepted (not only powers of two). An `align` of zero never
/// matches: the function returns `false` instead of dividing by zero.
pub fn is_memory_aligned(ptr: usize, align: usize) -> bool {
    ptr.checked_rem(align).is_some_and(|remainder| remainder == 0)
}

/// Backing storage for [`AlignedMemory`]: a single block of bytes aligned to `ALIGN`, whose
/// initialized length can grow up to the fixed capacity but never beyond it.
struct AlignedVec<const ALIGN: usize> {
    /// Start of the allocation, or a non-null dangling pointer when `capacity` is zero.
    ptr: NonNull<u8>,
    /// Number of leading bytes that are initialized and exposed through slices.
    length: usize,
    /// Size in bytes of the allocation behind `ptr`.
    capacity: usize,
}

205impl<const ALIGN: usize> Drop for AlignedVec<ALIGN> {
206    fn drop(&mut self) {
207        if self.capacity == 0 {
208            return;
209        }
210        let ptr = self.ptr.as_ptr();
211        unsafe {
212            // SAFETY: Layout is checked on construction
213            let layout = Layout::from_size_align_unchecked(self.capacity, ALIGN);
214            dealloc(ptr, layout);
215        }
216    }
217}
218
219impl<const A: usize> std::fmt::Debug for AlignedVec<A> {
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        f.debug_list().entries(self.as_slice()).finish()
222    }
223}
224
225impl<const A: usize> PartialEq for AlignedVec<A> {
226    fn eq(&self, other: &Self) -> bool {
227        self.as_slice() == other.as_slice()
228    }
229}
230
231impl<const A: usize> Eq for AlignedVec<A> {}
232
233impl<const ALIGN: usize> AlignedVec<ALIGN> {
234    /// Allocates a [`Vec<u8>`] with the requested alignment.
235    /// Ensure that the Vec is only dropped with the correct layout
236    ///
237    /// # Panics
238    /// Panics if the requested size is incompatible with the requested alignment or if allocation fails.
239    fn new(max_len: usize, zeroed: bool) -> Self {
240        assert!(ALIGN != 0, "Alignment must not be zero");
241        if max_len == 0 {
242            return Self::empty();
243        }
244        unsafe {
245            let layout = Layout::from_size_align(max_len, ALIGN).expect("invalid layout");
246            // SAFETY: Layout is non-zero, and allocation errors are handled
247            let ptr = if zeroed {
248                alloc_zeroed(layout)
249            } else {
250                alloc(layout)
251            };
252            if ptr.is_null() {
253                handle_alloc_error(layout);
254            }
255            Self {
256                ptr: NonNull::new(ptr).unwrap_or_else(|| handle_alloc_error(layout)),
257                length: 0,
258                capacity: max_len,
259            }
260        }
261    }
262
263    fn as_slice(&self) -> &[u8] {
264        unsafe { core::slice::from_raw_parts(self.ptr.as_ptr().cast_const(), self.length) }
265    }
266
267    fn as_slice_mut(&mut self) -> &mut [u8] {
268        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.length) }
269    }
270
271    fn empty() -> Self {
272        Self {
273            // Create a dangling pointer
274            // FIXME: Use `Layout::dangling_ptr` once Rust 1.95.0 is released
275            ptr: NonNull::new(ALIGN as *mut u8).expect("alignment may not be zero"),
276            length: 0,
277            capacity: 0,
278        }
279    }
280
281    fn as_mut_ptr(&mut self) -> *mut u8 {
282        self.ptr.as_ptr()
283    }
284
285    /// Returns a pointer to the end of the current initialized length, i.e.
286    /// `mem.as_mut_ptr().mem(self.len())`.
287    /// Users must ensure that any writes to this pointer are in bounds of `capacity`
288    fn write_ptr(&mut self) -> *mut u8 {
289        unsafe { self.as_mut_ptr().add(self.len()) }
290    }
291
292    /// Similar to [`write_ptr`], but checks that there is room for the write.
293    /// Returns (pointer, new_length)
294    fn write_ptr_for(&mut self, bytes: usize) -> Option<(*mut u8, usize)> {
295        let ptr = self.write_ptr();
296        let new_len = self
297            .len()
298            .checked_add(bytes)
299            .filter(|l| *l <= self.capacity())?;
300        Some((ptr, new_len))
301    }
302
303    fn len(&self) -> usize {
304        self.length
305    }
306
307    fn capacity(&self) -> usize {
308        self.capacity
309    }
310
311    fn is_empty(&self) -> bool {
312        self.len() == 0
313    }
314
315    /// Set the length of the `AlignedVec`. The new length must be less than or equal to
316    /// the capacity, and the memory must be initialized up to that length.
317    /// The new length must not be less than the previous length.
318    unsafe fn set_len(&mut self, new_len: usize) {
319        debug_assert!(
320            new_len <= self.capacity,
321            "attempted to grow AlignedVec beyond capacity"
322        );
323        debug_assert!(new_len >= self.length, "attempted to shrink AlignedVec");
324        self.length = new_len;
325    }
326}
327
328/// `AlignedVec` is [`Send`] as `u8` is `Send` and the data behind the pointer is uniquely owned.
329unsafe impl<const N: usize> Send for AlignedVec<N> {}
330
331/// `AlignedVec` is [`Sync`] as `u8` is `Send` and the data behind the pointer is uniquely owned.
332unsafe impl<const N: usize> Sync for AlignedVec<N> {}
333
#[allow(clippy::arithmetic_side_effects)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Exercises every constructor and write path for a single alignment.
    fn do_test<const ALIGN: usize>() {
        check_plain_writes::<ALIGN>();
        check_zeroed_fill_writes::<ALIGN>();
        check_zero_filled::<ALIGN>();
        check_unchecked_writes::<ALIGN>();
    }

    /// `with_capacity` buffers are aligned, accept writes up to capacity and reject more.
    fn check_plain_writes<const ALIGN: usize>() {
        let mut buf = AlignedMemory::<ALIGN>::with_capacity(10);
        let base = buf.mem.as_mut_ptr();
        assert_eq!(
            base.addr() & (ALIGN - 1),
            0,
            "memory is not correctly aligned"
        );

        assert_eq!(buf.write(&[42u8; 1]).unwrap(), 1);
        assert_eq!(buf.write(&[42u8; 9]).unwrap(), 9);
        assert_eq!(buf.as_slice(), &[42u8; 10]);
        // An empty write always fits, even when the buffer is full.
        assert_eq!(buf.write(&[42u8; 0]).unwrap(), 0);
        assert_eq!(buf.as_slice(), &[42u8; 10]);
        // A write past capacity fails and leaves the contents untouched.
        buf.write(&[42u8; 1]).unwrap_err();
        assert_eq!(buf.as_slice(), &[42u8; 10]);
        buf.as_slice_mut().copy_from_slice(&[84u8; 10]);
        assert_eq!(buf.as_slice(), &[84u8; 10]);
    }

    /// Zero and non-zero fills interleave correctly with regular writes on zeroed memory.
    fn check_zeroed_fill_writes<const ALIGN: usize>() {
        let expected: [u8; 10] = [0, 0, 0, 0, 0, 1, 1, 2, 2, 2];
        let mut buf = AlignedMemory::<ALIGN>::with_capacity_zeroed(10);
        buf.fill_write(5, 0).unwrap();
        buf.fill_write(2, 1).unwrap();
        assert_eq!(buf.write(&[2u8; 3]).unwrap(), 3);
        assert_eq!(buf.as_slice(), &expected);
        buf.fill_write(1, 3).unwrap_err();
        buf.write(&[4u8; 1]).unwrap_err();
        assert_eq!(buf.as_slice(), &expected);
    }

    /// `zero_filled` exposes its whole capacity as zero bytes.
    fn check_zero_filled<const ALIGN: usize>() {
        let buf = AlignedMemory::<ALIGN>::zero_filled(10);
        assert_eq!(buf.len(), 10);
        assert_eq!(buf.as_slice(), &[0u8; 10]);
    }

    /// Unchecked scalar and slice writes land at the expected offsets.
    fn check_unchecked_writes<const ALIGN: usize>() {
        const MAGIC: u64 = 0xCAFEBADDDEADCAFE;
        let word = mem::size_of::<u64>();
        let mut buf = AlignedMemory::<ALIGN>::with_capacity_zeroed(15);
        // SAFETY: 1 + 8 + 3 + 3 = 15 bytes are written, exactly the capacity.
        unsafe {
            buf.write_unchecked::<u8>(42);
            assert_eq!(buf.len(), 1);
            buf.write_unchecked::<u64>(MAGIC);
            assert_eq!(buf.len(), 9);
            buf.fill_write(3, 0).unwrap();
            buf.write_all_unchecked(b"foo");
            assert_eq!(buf.len(), 15);
        }
        let bytes = buf.as_slice();
        assert_eq!(bytes[0], 42);
        let (word_bytes, rest) = bytes[1..].split_at(word);
        // SAFETY: `word_bytes` holds exactly `size_of::<u64>()` bytes.
        let read_back = unsafe { core::ptr::read_unaligned::<u64>(word_bytes.as_ptr().cast()) };
        assert_eq!(read_back, MAGIC);
        let (padding, tail) = rest.split_at(3);
        assert_eq!(padding, &[0u8, 0, 0]);
        assert_eq!(tail, b"foo");
    }

    #[test]
    fn test_aligned_memory() {
        let runs: [fn(); 3] = [do_test::<1>, do_test::<16>, do_test::<32768>];
        for run in runs {
            run();
        }
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "<= self.mem.capacity()")]
    fn test_write_unchecked_debug_assert() {
        let mut buf = AlignedMemory::<8>::with_capacity(15);
        // The second u64 needs bytes 8..16 but only 15 exist, tripping the debug assertion.
        unsafe {
            buf.write_unchecked::<u64>(42);
            buf.write_unchecked::<u64>(24);
        }
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "<= self.mem.capacity()")]
    fn test_write_all_unchecked_debug_assert() {
        let mut buf = AlignedMemory::<8>::with_capacity(5);
        // "foo" + "bar" needs 6 bytes but only 5 exist, tripping the debug assertion.
        unsafe {
            buf.write_all_unchecked(b"foo");
            buf.write_all_unchecked(b"bar");
        }
    }

    const fn assert_send<T: Send>() {}
    const fn assert_sync<T: Sync>() {}
    const fn assert_unpin<T: Unpin>() {}

    // Compile-time checks of the auto traits `AlignedMemory` is expected to have.
    const _: () = {
        assert_send::<AlignedMemory<8>>();
        assert_sync::<AlignedMemory<8>>();
        assert_unpin::<AlignedMemory<8>>();
    };
}