tensorgraph_sys/
device.rs

//! Provides the [`Device`] trait definition and its implementations.
2
3use std::alloc::Layout;
4
5use crate::ptr::{NonNull, Ref};
6
7pub mod cpu;
8
9#[cfg(feature = "cuda")]
10pub mod cuda;
11
12/// Represents a physical device that can host memory.
13/// For example: [`cpu::Cpu`], [`cuda::Cuda`]
14pub trait Device: Sized {
15    #![allow(clippy::missing_safety_doc)]
16
17    type Ptr<T: ?Sized>: DevicePtr<T>;
18    const IS_CPU: bool = false;
19
20    fn copy_from_host<T: Copy>(from: &[T], to: &mut Ref<[T], Self>);
21    fn copy_to_host<T: Copy>(from: &Ref<[T], Self>, to: &mut [T]);
22    fn copy<T: Copy>(from: &Ref<[T], Self>, to: &mut Ref<[T], Self>);
23}
24
25/// Defines the default allocator for a device.
26/// For instance, The default allocator for [`cpu::Cpu`] is [`std::alloc::Global`]
27pub trait DefaultDeviceAllocator: Device {
28    type Alloc: DeviceAllocator<Device = Self> + Default;
29}
30
/// Represents a type-safe, device-based pointer.
///
/// For the CPU, this will be just `*mut T`. Implementors provide
/// [`as_raw`](DevicePtr::as_raw), [`from_raw`](DevicePtr::from_raw) and
/// [`write`](DevicePtr::write); the default pointer-arithmetic methods round-trip
/// through the raw pointer and use the corresponding `*mut T` methods.
pub trait DevicePtr<T: ?Sized>: Copy {
    /// Returns the underlying raw pointer.
    fn as_raw(self) -> *mut T;

    /// Wraps a raw pointer in this device pointer type.
    fn from_raw(ptr: *mut T) -> Self;

    /// Writes `val` to the location addressed by this pointer.
    ///
    /// # Safety
    /// Pointer must be valid and aligned.
    unsafe fn write(self, val: T)
    where
        T: Sized;

    /// Moves the pointer forwards by `count` elements of `T`.
    ///
    /// # Safety
    /// - Offset (`count * size_of::<T>()` bytes) should not overflow `isize`.
    /// - Resulting pointer should not overflow `usize`.
    /// - Resulting pointer must be in bounds of an allocated buffer.
    #[must_use]
    unsafe fn add(self, count: usize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller upholds this method's contract, which covers the
        // requirements of `<*mut T>::add`.
        Self::from_raw(unsafe { self.as_raw().add(count) })
    }

    /// Moves the pointer backwards by `count` elements of `T`.
    ///
    /// # Safety
    /// - Offset (`count * size_of::<T>()` bytes) should not overflow `isize`.
    /// - Resulting pointer should not underflow `usize`.
    /// - Resulting pointer must be in bounds of an allocated buffer.
    #[must_use]
    unsafe fn sub(self, count: usize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller upholds this method's contract, which covers the
        // requirements of `<*mut T>::sub`.
        Self::from_raw(unsafe { self.as_raw().sub(count) })
    }

    /// Moves the pointer by `count` elements of `T` (backwards when negative).
    ///
    /// # Safety
    /// - Offset (`count * size_of::<T>()` bytes) should not overflow `isize`.
    /// - Resulting pointer should not overflow or underflow `usize`.
    /// - Resulting pointer must be in bounds of an allocated buffer.
    #[must_use]
    unsafe fn offset(self, count: isize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller upholds this method's contract, which covers the
        // requirements of `<*mut T>::offset`.
        Self::from_raw(unsafe { self.as_raw().offset(count) })
    }
}
77
78/// An Allocator in a specific device.
79/// All [`std::alloc::Allocator`]s are [`DeviceAllocator<Device=cpu::Cpu>`]
80pub trait DeviceAllocator {
81    #![allow(clippy::missing_safety_doc)]
82
83    /// Error returned when failing to allocate
84    type AllocError: std::error::Error;
85    type Device: Device;
86
87    /// Create a new allocation
88    ///
89    /// # Errors
90    /// If the device fails, is not ready, or the allocation was invalid
91    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
92
93    /// Create a new allocation with zeroes
94    ///
95    /// # Errors
96    /// If the device fails, is not ready, or the allocation was invalid
97    fn allocate_zeroed(
98        &self,
99        layout: Layout,
100    ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
101
102    unsafe fn deallocate(&self, ptr: NonNull<u8, Self::Device>, layout: Layout);
103
104    /// Grows an allocation
105    ///
106    /// # Errors
107    /// If the device fails, is not ready, or the allocation was invalid
108    unsafe fn grow(
109        &self,
110        ptr: NonNull<u8, Self::Device>,
111        old_layout: Layout,
112        new_layout: Layout,
113    ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
114
115    /// Grows an allocation with zeroes
116    ///
117    /// # Errors
118    /// If the device fails, is not ready, or the allocation was invalid
119    unsafe fn grow_zeroed(
120        &self,
121        ptr: NonNull<u8, Self::Device>,
122        old_layout: Layout,
123        new_layout: Layout,
124    ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
125
126    /// Shrinks an allocation
127    ///
128    /// # Errors
129    /// If the device fails, is not ready, or the allocation was invalid
130    unsafe fn shrink(
131        &self,
132        ptr: NonNull<u8, Self::Device>,
133        old_layout: Layout,
134        new_layout: Layout,
135    ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
136}