tensorgraph_sys/
vec.rs

1use std::{
2    alloc::{Allocator, Global},
3    borrow::Borrow,
4    mem::{ManuallyDrop, MaybeUninit},
5    ops::{Deref, DerefMut},
6};
7
8use crate::{
9    boxed::Box,
10    device::{cpu::Cpu, DefaultDeviceAllocator, Device, DeviceAllocator, DevicePtr},
11    ptr::{NonNull, Ref},
12    zero::Zero,
13};
14
15/// Same as [`std::vec::Vec`] but using device allocators rather than host allocators.
16/// This allows you to have owned buffers on GPUs and CPUs using a single data structure.
17pub struct Vec<T, A: DeviceAllocator = Global> {
18    buf: Box<[MaybeUninit<T>], A>,
19    len: usize,
20}
21
22/// A [`Vec`] that uses the default allocator for the device
23pub type DefaultVec<T, D = Cpu> = Vec<T, <D as DefaultDeviceAllocator>::Alloc>;
24
25impl<T, A: DeviceAllocator> Drop for Vec<T, A> {
26    fn drop(&mut self) {
27        unsafe {
28            // drop the data
29            if std::mem::needs_drop::<T>() {
30                // we are on the CPU
31                if A::Device::IS_CPU {
32                    let slice = &mut *(self.buf.ptr.as_ptr().as_raw());
33                    let slice = &mut slice[..self.len];
34                    for i in slice {
35                        std::ptr::drop_in_place(i);
36                    }
37                } else {
38                    panic!("drop types should not be initialised outside of the CPU")
39                }
40            }
41        }
42    }
43}
44
45impl<T: Copy, A: DeviceAllocator + Clone> Clone for Vec<T, A> {
46    fn clone(&self) -> Self {
47        let slice = &*self;
48        unsafe {
49            let mut vec = Self::with_capacity_in(slice.len(), self.buf.allocator().clone());
50            vec.space_capacity_mut().init_from_slice(slice);
51            vec.set_len(slice.len());
52            vec
53        }
54    }
55}
56
57impl<T, A: DeviceAllocator> Vec<T, A> {
58    pub fn from_box(b: Box<[T], A>) -> Self {
59        let len = b.len();
60        unsafe { Self::from_raw_parts(b.into_uninit(), len) }
61    }
62
63    pub fn zeroed_in(len: usize, alloc: A) -> Self
64    where
65        T: Zero,
66    {
67        Self::from_box(Box::zeroed(len, alloc))
68    }
69
70    #[must_use]
71    pub fn zeroed(len: usize) -> Self
72    where
73        T: Zero,
74        A: Default,
75    {
76        Self::zeroed_in(len, A::default())
77    }
78
79    pub fn copy_from_host_in(slice: &[T], alloc: A) -> Self
80    where
81        T: Copy,
82    {
83        unsafe {
84            let mut vec = Self::with_capacity_in(slice.len(), alloc);
85            vec.space_capacity_mut().init_from_host(slice);
86            vec.set_len(slice.len());
87            vec
88        }
89    }
90
91    pub fn copy_from_host(slice: &[T]) -> Self
92    where
93        T: Copy,
94        A: Default,
95    {
96        Self::copy_from_host_in(slice, A::default())
97    }
98
99    /// # Safety
100    /// `buf` must be a valid allocation in `device`, and `len` items must be initialised
101    pub unsafe fn from_raw_parts(buf: Box<[MaybeUninit<T>], A>, len: usize) -> Self {
102        Self { buf, len }
103    }
104
105    pub fn into_raw_parts(self) -> (Box<[MaybeUninit<T>], A>, usize) {
106        let v = ManuallyDrop::new(self);
107        unsafe { (std::ptr::read(&v.buf), v.len) }
108    }
109
110    #[must_use]
111    pub fn with_capacity(capacity: usize) -> Self
112    where
113        A: Default,
114    {
115        Self::with_capacity_in(capacity, A::default())
116    }
117
118    pub fn with_capacity_in(capacity: usize, alloc: A) -> Self {
119        let buf = Box::with_capacity(capacity, alloc);
120        unsafe { Self::from_raw_parts(buf, 0) }
121    }
122
123    pub fn len(&self) -> usize {
124        self.len
125    }
126
127    pub fn is_empty(&self) -> bool {
128        self.len == 0
129    }
130
131    pub fn capacity(&self) -> usize {
132        self.buf.len()
133    }
134
135    pub fn space_capacity_mut(&mut self) -> &mut Ref<[MaybeUninit<T>], A::Device> {
136        &mut self.buf.deref_mut()[self.len..]
137    }
138
139    /// # Safety
140    /// If len is smaller than the current length, the caller must ensure they drop the values.
141    /// If the len is greater than the current length, the caller must ensure they have initialised those values
142    pub unsafe fn set_len(&mut self, len: usize) {
143        self.len = len;
144    }
145
146    unsafe fn ensure(&mut self, capacity: usize) {
147        let old = self.capacity();
148        if capacity > old {
149            let new = match capacity {
150                1..=4 => 4,
151                n => n.next_power_of_two(),
152            };
153
154            self.buf.resize(new);
155        }
156    }
157
158    pub fn push(&mut self, val: T) {
159        unsafe {
160            self.ensure(self.len + 1);
161            self.buf.ptr.cast::<T>().as_ptr().add(self.len).write(val);
162            self.len += 1;
163        }
164    }
165}
166
167impl<T, A: Allocator> From<std::vec::Vec<T, A>> for Vec<T, A> {
168    fn from(v: std::vec::Vec<T, A>) -> Self {
169        unsafe {
170            let (ptr, len, cap, alloc) = v.into_raw_parts_with_alloc();
171            let data = NonNull::new_unchecked(ptr.cast());
172            let ptr = NonNull::slice_from_raw_parts(data, cap);
173            let buf = Box::from_raw_parts(ptr, alloc);
174            Self::from_raw_parts(buf, len)
175        }
176    }
177}
178
179impl<T, A: Allocator> From<Vec<T, A>> for std::vec::Vec<T, A> {
180    fn from(v: Vec<T, A>) -> Self {
181        unsafe {
182            let (buf, len) = v.into_raw_parts();
183            let (ptr, alloc) = buf.into_raw_parts();
184            let (ptr, cap) = ptr.as_ptr().to_raw_parts();
185            Self::from_raw_parts_in(ptr.cast(), len, cap, alloc)
186        }
187    }
188}
189
190impl<T, A: Allocator> Vec<T, A> {
191    pub fn into_std(self) -> std::vec::Vec<T, A> {
192        self.into()
193    }
194}
195
196impl<T, A: DeviceAllocator> Deref for Vec<T, A> {
197    type Target = Ref<[T], A::Device>;
198
199    fn deref(&self) -> &Self::Target {
200        unsafe { self.buf.deref()[..self.len()].assume_init() }
201    }
202}
203
204impl<T, A: DeviceAllocator> DerefMut for Vec<T, A> {
205    fn deref_mut(&mut self) -> &mut Self::Target {
206        unsafe { self.buf.deref_mut()[..self.len].assume_init_mut() }
207    }
208}
209
210impl<T, A: DeviceAllocator> Borrow<Ref<[T], A::Device>> for Vec<T, A> {
211    fn borrow(&self) -> &Ref<[T], A::Device> {
212        self
213    }
214}
215
216impl<T, A: DeviceAllocator> AsRef<Ref<[T], A::Device>> for Vec<T, A> {
217    fn as_ref(&self) -> &Ref<[T], A::Device> {
218        self
219    }
220}
221
222impl<T, A: DeviceAllocator<Device = Cpu>> AsRef<[T]> for Vec<T, A> {
223    fn as_ref(&self) -> &[T] {
224        self
225    }
226}
227
228impl<T, A: DeviceAllocator> AsMut<Ref<[T], A::Device>> for Vec<T, A> {
229    fn as_mut(&mut self) -> &mut Ref<[T], A::Device> {
230        self
231    }
232}
233
234impl<T, A: DeviceAllocator<Device = Cpu>> AsMut<[T]> for Vec<T, A> {
235    fn as_mut(&mut self) -> &mut [T] {
236        self
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::Vec;

    /// Pushing into an empty vector grows the capacity to 4, then to 8 once those 4 are used.
    #[test]
    fn push() {
        let mut v = Vec::<_>::with_capacity(0);
        assert_eq!(v.capacity(), 0);

        // (value pushed, capacity expected right after the push)
        for (value, expected_capacity) in [(0, 4), (1, 4), (2, 4), (3, 4), (4, 8)] {
            v.push(value);
            assert_eq!(v.capacity(), expected_capacity);
        }
    }

    /// A vector filled by `push` compares equal to the matching standard vector, both as a
    /// slice and after conversion with `into_std`.
    #[test]
    fn convert() {
        let mut v1 = Vec::with_capacity(0);
        for value in 0..5 {
            v1.push(value);
        }

        let v2 = vec![0, 1, 2, 3, 4];

        assert_eq!(&**v1, v2.as_slice());
        assert_eq!(v1.into_std(), v2);
    }
}