tensorgraph_sys/ptr/non_null.rs

1use std::{fmt::Debug, marker::PhantomData, ptr::Pointee};
2
3use crate::device::{Device, DevicePtr};
4
5/// Same as [`std::ptr::NonNull<T>`] but backed by a [`Device::Ptr`] instead of a raw pointer
6pub struct NonNull<T: ?Sized, D: Device> {
7    inner: std::ptr::NonNull<T>,
8    marker: PhantomData<D>,
9}
10
11impl<T: ?Sized, D: Device> Debug for NonNull<T, D> {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        f.debug_tuple("NonNull").field(&self.inner).finish()
14    }
15}
16
17impl<T: ?Sized, D: Device> Clone for NonNull<T, D> {
18    #[inline]
19    fn clone(&self) -> Self {
20        *self
21    }
22}
23
24impl<T: ?Sized, D: Device> Copy for NonNull<T, D> {}
25
26impl<T: ?Sized, D: Device> NonNull<T, D> {
27    pub fn new(ptr: D::Ptr<T>) -> Option<Self> {
28        let inner = std::ptr::NonNull::new(ptr.as_raw())?;
29        Some(Self {
30            inner,
31            marker: PhantomData,
32        })
33    }
34
35    /// # Safety
36    /// ptr must not be null
37    pub unsafe fn new_unchecked(ptr: D::Ptr<T>) -> Self {
38        let inner = std::ptr::NonNull::new_unchecked(ptr.as_raw());
39        Self {
40            inner,
41            marker: PhantomData,
42        }
43    }
44
45    #[must_use]
46    pub fn as_ptr(self) -> D::Ptr<T> {
47        D::Ptr::from_raw(self.inner.as_ptr())
48    }
49
50    #[must_use]
51    pub fn cast<U>(self) -> NonNull<U, D> {
52        let Self { inner, marker } = self;
53        NonNull {
54            inner: inner.cast(),
55            marker,
56        }
57    }
58
59    #[must_use]
60    pub fn to_raw_parts(self) -> (NonNull<(), D>, <T as Pointee>::Metadata) {
61        let (ptr, meta) = self.inner.as_ptr().to_raw_parts();
62        let ptr = D::Ptr::from_raw(ptr);
63        let data = unsafe { NonNull::new_unchecked(ptr) };
64        (data, meta)
65    }
66}
67
68impl<T, D: Device> NonNull<[T], D> {
69    #[must_use]
70    pub fn slice_from_raw_parts(data: NonNull<T, D>, len: usize) -> Self {
71        let NonNull { inner, marker } = data;
72        let inner = std::ptr::NonNull::slice_from_raw_parts(inner, len);
73        Self { inner, marker }
74    }
75}