tensorgraph_sys/
device.rs1use std::alloc::Layout;
4
5use crate::ptr::{NonNull, Ref};
6
7pub mod cpu;
8
9#[cfg(feature = "cuda")]
10pub mod cuda;
11
12pub trait Device: Sized {
15 #![allow(clippy::missing_safety_doc)]
16
17 type Ptr<T: ?Sized>: DevicePtr<T>;
18 const IS_CPU: bool = false;
19
20 fn copy_from_host<T: Copy>(from: &[T], to: &mut Ref<[T], Self>);
21 fn copy_to_host<T: Copy>(from: &Ref<[T], Self>, to: &mut [T]);
22 fn copy<T: Copy>(from: &Ref<[T], Self>, to: &mut Ref<[T], Self>);
23}
24
25pub trait DefaultDeviceAllocator: Device {
28 type Alloc: DeviceAllocator<Device = Self> + Default;
29}
30
/// A device-specific pointer to a (possibly unsized) `T`.
///
/// Implementors convert to and from a raw `*mut T` through [`as_raw`](DevicePtr::as_raw)
/// and [`from_raw`](DevicePtr::from_raw). The provided arithmetic helpers are built on that
/// round trip and never dereference the pointer.
pub trait DevicePtr<T: ?Sized>: Copy {
    /// Returns the underlying raw pointer.
    fn as_raw(self) -> *mut T;

    /// Wraps a raw pointer in this device pointer type.
    fn from_raw(ptr: *mut T) -> Self;

    /// Writes `val` to the location this pointer refers to.
    ///
    /// # Safety
    ///
    /// The exact requirements are implementation-defined.
    /// NOTE(review): presumably `self` must be non-null, aligned and valid for writes of `T`,
    /// as for `<*mut T>::write` — confirm against the implementations.
    unsafe fn write(self, val: T)
    where
        T: Sized;

    /// Offsets the pointer forward by `count` elements of `T`.
    ///
    /// # Safety
    ///
    /// The provided implementation forwards to `<*mut T>::add`, so the caller must uphold its
    /// contract. The resulting pointer must stay within, or one past the end of, the same
    /// allocation, and `count * size_of::<T>()` must not overflow `isize`.
    #[must_use]
    unsafe fn add(self, count: usize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller guarantees `<*mut T>::add`'s contract (see `# Safety`).
        Self::from_raw(unsafe { self.as_raw().add(count) })
    }

    /// Offsets the pointer backward by `count` elements of `T`.
    ///
    /// # Safety
    ///
    /// The provided implementation forwards to `<*mut T>::sub`, so the caller must uphold its
    /// contract. The resulting pointer must stay within the same allocation, and
    /// `count * size_of::<T>()` must not overflow `isize`.
    #[must_use]
    unsafe fn sub(self, count: usize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller guarantees `<*mut T>::sub`'s contract (see `# Safety`).
        Self::from_raw(unsafe { self.as_raw().sub(count) })
    }

    /// Offsets the pointer by a signed `count` elements of `T`.
    ///
    /// # Safety
    ///
    /// The provided implementation forwards to `<*mut T>::offset`, so the caller must uphold
    /// its contract. The resulting pointer must stay within, or one past the end of, the same
    /// allocation, and `count * size_of::<T>()` must not overflow `isize`.
    #[must_use]
    unsafe fn offset(self, count: isize) -> Self
    where
        T: Sized,
    {
        // SAFETY: the caller guarantees `<*mut T>::offset`'s contract (see `# Safety`).
        Self::from_raw(unsafe { self.as_raw().offset(count) })
    }
}
77
78pub trait DeviceAllocator {
81 #![allow(clippy::missing_safety_doc)]
82
83 type AllocError: std::error::Error;
85 type Device: Device;
86
87 fn allocate(&self, layout: Layout) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
92
93 fn allocate_zeroed(
98 &self,
99 layout: Layout,
100 ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
101
102 unsafe fn deallocate(&self, ptr: NonNull<u8, Self::Device>, layout: Layout);
103
104 unsafe fn grow(
109 &self,
110 ptr: NonNull<u8, Self::Device>,
111 old_layout: Layout,
112 new_layout: Layout,
113 ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
114
115 unsafe fn grow_zeroed(
120 &self,
121 ptr: NonNull<u8, Self::Device>,
122 old_layout: Layout,
123 new_layout: Layout,
124 ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
125
126 unsafe fn shrink(
131 &self,
132 ptr: NonNull<u8, Self::Device>,
133 old_layout: Layout,
134 new_layout: Layout,
135 ) -> Result<NonNull<[u8], Self::Device>, Self::AllocError>;
136}