tensorgraph_sys/boxed.rs

use std::{
    alloc::{Allocator, Global, Layout},
    marker::PhantomData,
    mem::{align_of_val, size_of_val, MaybeUninit},
    ops::{Deref, DerefMut},
};

use crate::{
    device::{DeviceAllocator, DevicePtr},
    ptr::{NonNull, Ref},
    zero::Zero,
};

/// Similar to [`std::boxed::Box`], but generic over the device on which the allocation lives.
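///
/// # Examples
///
/// A minimal host-side sketch, assuming [`Global`] can serve as the
/// [`DeviceAllocator`] for CPU memory and that `f32` implements [`Zero`]
/// (both are assumptions of this example, not guarantees):
///
/// ```ignore
/// use std::alloc::Global;
///
/// // allocate a zero-initialized boxed slice of 16 floats
/// let b: Box<[f32], Global> = Box::zeroed(16, Global);
/// assert_eq!(b.len(), 16);
/// ```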
pub struct Box<T: ?Sized, A: DeviceAllocator = Global> {
    pub(crate) ptr: NonNull<T, A::Device>,
    alloc: A,

    /// Signifies that this `Box` owns the `T`.
    _marker: PhantomData<T>,
}

impl<T: ?Sized, A: DeviceAllocator> Box<T, A> {
    /// Decomposes this `Box` into its raw pointer and allocator without
    /// running the destructor.
    pub fn into_raw_parts(self) -> (NonNull<T, A::Device>, A) {
        let b = std::mem::ManuallyDrop::new(self);
        (b.ptr, unsafe { std::ptr::read(&b.alloc) })
    }

    /// # Safety
    /// The pointer must be a valid allocation within `alloc`.
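    ///
    /// # Examples
    ///
    /// Round-tripping through [`Box::into_raw_parts`]; a sketch assuming a
    /// host [`Global`] allocator:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    ///
    /// let b: Box<[f32], Global> = Box::zeroed(4, Global);
    /// let (ptr, alloc) = b.into_raw_parts();
    /// // `ptr` is a live allocation owned by `alloc`, so rebuilding the box is sound
    /// let b = unsafe { Box::from_raw_parts(ptr, alloc) };
    /// ```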
    pub unsafe fn from_raw_parts(ptr: NonNull<T, A::Device>, alloc: A) -> Self {
        Self {
            ptr,
            alloc,
            _marker: PhantomData,
        }
    }

    /// Returns a reference to the underlying allocator.
    pub fn allocator(&self) -> &A {
        &self.alloc
    }
}

impl<T: ?Sized, A: Allocator> Box<T, A> {
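    /// Converts this [`Box`] into a [`std::boxed::Box`], reusing the same
    /// allocation. This is only available when the allocator is a host
    /// [`Allocator`].
    ///
    /// A sketch, assuming [`Global`] doubles as the device allocator:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    ///
    /// let b: Box<[f32], Global> = Box::zeroed(4, Global);
    /// let std_box: std::boxed::Box<[f32], Global> = b.into_std();
    /// assert_eq!(std_box.len(), 4);
    /// ```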
    pub fn into_std(self) -> std::boxed::Box<T, A> {
        unsafe {
            let (ptr, alloc) = self.into_raw_parts();
            std::boxed::Box::from_raw_in(ptr.as_ptr(), alloc)
        }
    }
}

/// `Box` dereferences to the device-aware [`Ref`] wrapper rather than to `&T`.
impl<T: ?Sized, A: DeviceAllocator> Deref for Box<T, A> {
    type Target = Ref<T, A::Device>;

    fn deref(&self) -> &Self::Target {
        unsafe { Ref::from_ptr(self.ptr.as_ptr()) }
    }
}

impl<T, A: DeviceAllocator> DerefMut for Box<[T], A> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { Ref::from_ptr_mut(self.ptr.as_ptr()) }
    }
}

impl<T, A: DeviceAllocator> Box<[MaybeUninit<T>], A> {
    /// # Safety
    /// If this resize results in a shrink, the data that is lost must already
    /// have been dropped.
    ///
    /// # Panics
    /// If the allocation cannot be resized.
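    ///
    /// # Examples
    ///
    /// Growing an uninitialized buffer; a sketch assuming a host [`Global`]
    /// allocator:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    /// use std::mem::MaybeUninit;
    ///
    /// let mut buf: Box<[MaybeUninit<f32>], Global> = Box::with_capacity(8, Global);
    /// // growing never discards initialized data, so this call is sound
    /// unsafe { buf.resize(16) };
    /// assert_eq!(buf.len(), 16);
    /// ```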
    pub unsafe fn resize(&mut self, capacity: usize) {
        let new = capacity;
        let old = self.len();

        let layout = Layout::new::<T>();
        let old_layout = layout.repeat(old).unwrap().0;
        let new_layout = layout.repeat(new).unwrap().0;

        // grow, shrink, or keep the allocation depending on how the new
        // capacity compares to the old one
        let data = match new.cmp(&old) {
            std::cmp::Ordering::Greater => self
                .alloc
                .grow(self.ptr.cast(), old_layout, new_layout)
                .unwrap()
                .cast(),
            std::cmp::Ordering::Less => self
                .alloc
                .shrink(self.ptr.cast(), old_layout, new_layout)
                .unwrap()
                .cast(),
            std::cmp::Ordering::Equal => self.ptr.cast(),
        };

        self.ptr = NonNull::slice_from_raw_parts(data, new);
    }

    #[must_use]
    /// Creates a new uninitialized slice with the given capacity.
    /// # Panics
    /// If the allocation cannot be created.
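    ///
    /// # Examples
    ///
    /// A sketch, assuming a host [`Global`] allocator:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    /// use std::mem::MaybeUninit;
    ///
    /// let buf: Box<[MaybeUninit<f32>], Global> = Box::with_capacity(32, Global);
    /// // the capacity is reflected in the slice length
    /// assert_eq!(buf.len(), 32);
    /// ```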
    pub fn with_capacity(capacity: usize, alloc: A) -> Self {
        unsafe {
            let (layout, _) = Layout::new::<T>().repeat(capacity).unwrap();
            let data = alloc.allocate(layout).unwrap().cast();
            let buf = NonNull::slice_from_raw_parts(data, capacity);
            Self::from_raw_parts(buf, alloc)
        }
    }
}

impl<T, A: DeviceAllocator> Box<[T], A> {
    #[must_use]
    /// Creates a new zeroed slice with the given capacity.
    /// # Panics
    /// If the allocation cannot be created.
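    ///
    /// # Examples
    ///
    /// A sketch, assuming [`Global`] as the allocator and that `f32`
    /// implements [`Zero`]:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    ///
    /// // every element is zero-initialized by the allocator
    /// let b: Box<[f32], Global> = Box::zeroed(8, Global);
    /// assert_eq!(b.len(), 8);
    /// ```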
    pub fn zeroed(capacity: usize, alloc: A) -> Self
    where
        T: Zero,
    {
        unsafe {
            let (layout, _) = Layout::new::<T>().repeat(capacity).unwrap();
            let data = alloc.allocate_zeroed(layout).unwrap().cast();
            let buf = NonNull::slice_from_raw_parts(data, capacity);
            Self::from_raw_parts(buf, alloc)
        }
    }

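    /// Reinterprets this boxed slice as a boxed slice of [`MaybeUninit`]
    /// elements, reusing the allocation.
    ///
    /// A sketch, assuming a host [`Global`] allocator:
    ///
    /// ```ignore
    /// use std::alloc::Global;
    /// use std::mem::MaybeUninit;
    ///
    /// let b: Box<[f32], Global> = Box::zeroed(4, Global);
    /// let uninit: Box<[MaybeUninit<f32>], Global> = b.into_uninit();
    /// assert_eq!(uninit.len(), 4);
    /// ```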
    pub fn into_uninit(self) -> Box<[MaybeUninit<T>], A> {
        unsafe {
            let (ptr, alloc) = self.into_raw_parts();
            let (ptr, len) = ptr.to_raw_parts();
            let ptr = NonNull::slice_from_raw_parts(ptr.cast(), len);
            Box::from_raw_parts(ptr, alloc)
        }
    }
}

impl<T: ?Sized, A: DeviceAllocator> Drop for Box<T, A> {
    fn drop(&mut self) {
        unsafe {
            // recover the layout (size and alignment, including any slice
            // metadata) from the value behind the pointer
            let ref_ = &*(self.ptr.as_ptr().as_raw());
            let size = size_of_val(ref_);
            let align = align_of_val(ref_);
            let layout = Layout::from_size_align_unchecked(size, align);
            self.alloc.deallocate(self.ptr.cast(), layout);
        }
    }
}