Skip to main content

slop_alloc/
slice.rs

1use std::{
2    alloc::Layout,
3    marker::PhantomData,
4    ops::{
5        Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo,
6        RangeToInclusive,
7    },
8};
9
10use crate::{
11    backend::CpuBackend,
12    mem::{CopyDirection, CopyError, DeviceMemory},
13    Allocator, Init,
14};
15
16/// A slice of data associated with a specific allocator type.
17///
18/// This type is enssentially a wrapper around a slice and has an indicator for the type of the
19/// allocator to induicate where the memory resides but it
20#[repr(transparent)]
21pub struct Slice<T, A = CpuBackend> {
22    allocator: PhantomData<A>,
23    slice: [T],
24}
25
26impl<T, A: Allocator> Slice<T, A> {
27    #[inline]
28    pub const fn len(&self) -> usize {
29        self.slice.len()
30    }
31
32    #[inline]
33    pub const fn is_empty(&self) -> bool {
34        self.slice.is_empty()
35    }
36
37    #[inline]
38    pub fn as_ptr(&self) -> *const T {
39        self.slice.as_ptr()
40    }
41
42    #[inline]
43    pub fn as_mut_ptr(&mut self) -> *mut T {
44        self.slice.as_mut_ptr()
45    }
46
47    #[inline(always)]
48    pub(crate) unsafe fn from_slice(src: &[T]) -> &Self {
49        &*(src as *const [T] as *const Self)
50    }
51
52    /// # Safety
53    #[inline]
54    pub unsafe fn from_raw_parts<'a>(data: *const T, len: usize) -> &'a Self {
55        Self::from_slice(std::slice::from_raw_parts(data, len))
56    }
57
58    #[inline(always)]
59    pub(crate) unsafe fn from_slice_mut(src: &mut [T]) -> &mut Self {
60        &mut *(src as *mut [T] as *mut Self)
61    }
62
63    /// # Safety
64    pub unsafe fn from_raw_parts_mut<'a>(data: *mut T, len: usize) -> &'a mut Self {
65        Self::from_slice_mut(std::slice::from_raw_parts_mut(data, len))
66    }
67
68    #[inline]
69    pub fn split_at_mut(&mut self, mid: usize) -> (&mut Self, &mut Self) {
70        let (left, right) = self.slice.split_at_mut(mid);
71        unsafe { (Self::from_slice_mut(left), Self::from_slice_mut(right)) }
72    }
73
74    #[inline]
75    pub fn split_at(&self, mid: usize) -> (&Self, &Self) {
76        let (left, right) = self.slice.split_at(mid);
77        unsafe { (Self::from_slice(left), Self::from_slice(right)) }
78    }
79
80    /// Copies all elements from `src` into `self`, using `copy_nonoverlapping`.
81    ///
82    /// The length of `src` must be the same as `self`.
83    ///
84    /// # Panics
85    ///
86    /// This function will panic if the two slices have different lengths or if cudaMalloc
87    /// returned an error.
88    ///
89    /// # Safety
90    /// This operation is potentially asynchronous. The caller must insure the memory of the source
91    /// is valid for the duration of the operation.
92    #[inline]
93    #[track_caller]
94    pub unsafe fn copy_from_slice(
95        &mut self,
96        src: &Slice<T, A>,
97        allocator: &A,
98    ) -> Result<(), CopyError>
99    where
100        A: DeviceMemory,
101    {
102        // The panic code path was put into a cold function to not bloat the
103        // call site.
104        #[inline(never)]
105        #[cold]
106        #[track_caller]
107        fn len_mismatch_fail(dst_len: usize, src_len: usize) -> ! {
108            panic!(
109                "source slice length ({src_len}) does not match destination slice length ({dst_len})",
110            );
111        }
112
113        if self.len() != src.len() {
114            len_mismatch_fail(self.len(), src.len());
115        }
116
117        let layout = Layout::array::<T>(src.len()).unwrap();
118
119        unsafe {
120            allocator.copy_nonoverlapping(
121                src.as_ptr() as *const u8,
122                self.as_mut_ptr() as *mut u8,
123                layout.size(),
124                CopyDirection::DeviceToDevice,
125            )
126        }
127    }
128}
129
130macro_rules! impl_index {
131    ($($t:ty)*) => {
132        $(
133            impl<T, A: Allocator> Index<$t> for Slice<T, A>
134            {
135                type Output = Slice<T, A>;
136
137                fn index(&self, index: $t) -> &Self {
138                    unsafe { Slice::from_slice(self.slice.index(index)) }
139                }
140            }
141
142            impl<T, A: Allocator> IndexMut<$t> for Slice<T, A>
143            {
144                fn index_mut(&mut self, index: $t) -> &mut Self {
145                    unsafe { Slice::from_slice_mut( self.slice.index_mut(index)) }
146                }
147            }
148        )*
149    }
150}
151
152impl_index! {
153    Range<usize>
154    RangeFull
155    RangeFrom<usize>
156    RangeInclusive<usize>
157    RangeTo<usize>
158    RangeToInclusive<usize>
159}
160
161impl<T, A: Allocator> Index<usize> for Slice<T, A> {
162    type Output = Init<T, A>;
163
164    #[inline]
165    fn index(&self, index: usize) -> &Self::Output {
166        let ptr = self.slice.index(index) as *const T as *const Init<T, A>;
167        unsafe { ptr.as_ref().unwrap() }
168    }
169}
170
171impl<T, A: Allocator> IndexMut<usize> for Slice<T, A> {
172    #[inline]
173    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
174        let ptr = self.slice.index_mut(index) as *mut T as *mut Init<T, A>;
175        unsafe { ptr.as_mut().unwrap() }
176    }
177}
178
179impl<T> Slice<T, CpuBackend> {
180    #[inline]
181    pub fn to_vec(&self) -> Vec<T>
182    where
183        T: Clone,
184    {
185        self.slice.to_vec()
186    }
187}
188
189impl<T> Deref for Slice<T, CpuBackend> {
190    type Target = [T];
191
192    fn deref(&self) -> &Self::Target {
193        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len()) }
194    }
195}
196
197impl<T> DerefMut for Slice<T, CpuBackend> {
198    fn deref_mut(&mut self) -> &mut Self::Target {
199        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len()) }
200    }
201}
202
203impl<T: PartialEq> PartialEq for Slice<T, CpuBackend> {
204    fn eq(&self, other: &Self) -> bool {
205        self.slice == other.slice
206    }
207}
208
209impl<T: Eq> Eq for Slice<T, CpuBackend> {}