1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{DeviceBuffer, get_context};
17
18#[derive(Debug, Clone, Hash, PartialEq, Eq)]
21pub enum DeviceTensor {
22 Owned(Box<dyn OwnedDeviceTensor>),
23 ArenaView(DeviceArenaView),
24}
25
26impl DeviceTensor {
27 pub const SUPPORTED_DT: [DatumType; 11] = [
28 DatumType::Bool,
29 DatumType::F32,
30 DatumType::F16,
31 DatumType::I8,
32 DatumType::U8,
33 DatumType::I16,
34 DatumType::U16,
35 DatumType::I32,
36 DatumType::U32,
37 DatumType::I64,
38 DatumType::U64,
39 ];
40
41 pub fn tname(dt: DatumType) -> TractResult<&'static str> {
42 Ok(match dt {
43 DatumType::F32 => "f32",
44 DatumType::F16 => "f16",
45 DatumType::U8 => "u8",
46 DatumType::U16 => "u16",
47 DatumType::U32 => "u32",
48 DatumType::U64 => "u64",
49 DatumType::I8 => "i8",
50 DatumType::I16 => "i16",
51 DatumType::I32 => "i32",
52 DatumType::I64 => "i64",
53 DatumType::Bool => "bool",
54 _ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
55 })
56 }
57
58 pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
60 Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
61 }
62
63 pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
64 Self::uninitialized_dt(T::datum_type(), shape)
65 }
66
67 pub fn uninitialized_exotic(exotic_fact: Box<dyn ExoticFact>) -> TractResult<DeviceTensor> {
68 Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_exotic_tensor(exotic_fact)?))
69 }
70 pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
72 Tensor::from_shape(shape, data)?.into_device()
73 }
74
75 pub fn is_supported_dt(dt: DatumType) -> bool {
76 Self::SUPPORTED_DT.contains(&dt)
77 }
78
79 #[inline]
81 pub fn datum_type(&self) -> DatumType {
82 match self {
83 Self::Owned(owned) => owned.datum_type(),
84 Self::ArenaView(view) => view.datum_type(),
85 }
86 }
87
88 #[inline]
90 pub fn rank(&self) -> usize {
91 self.shape().len()
92 }
93
94 #[inline]
96 pub fn shape(&self) -> &[usize] {
97 match self {
98 Self::Owned(t) => t.shape(),
99 Self::ArenaView(t) => t.shape(),
100 }
101 }
102
103 #[inline]
105 #[allow(clippy::len_without_is_empty)]
106 pub fn len(&self) -> usize {
107 match self {
108 Self::Owned(t) => t.len(),
109 Self::ArenaView(t) => t.len(),
110 }
111 }
112
113 #[inline]
115 pub fn strides(&self) -> &[isize] {
116 match self {
117 Self::Owned(t) => t.strides(),
118 Self::ArenaView(t) => t.strides(),
119 }
120 }
121
122 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
124 match self {
125 Self::Owned(t) => t.device_buffer(),
126 Self::ArenaView(t) => t.device_buffer(),
127 }
128 }
129
130 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
132 where
133 usize: AsPrimitive<I>,
134 {
135 match self {
136 Self::Owned(_) => 0.as_(),
137 Self::ArenaView(t) => t.buffer_offset(),
138 }
139 }
140
141 pub fn device_buffer_ptr(&self) -> *const c_void {
142 match self {
143 Self::Owned(t) => t.device_buffer().ptr(),
144 Self::ArenaView(t) => t.device_buffer().ptr(),
145 }
146 }
147
148 pub fn description(&self) -> String {
150 format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
151 }
152
153 pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
155 match self {
156 Self::Owned(t) => Ok(t.reshaped(shape)?),
157 Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
158 }
159 }
160
161 pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
162 match self {
163 Self::Owned(t) => Ok(t.restrided(strides)?),
164 Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
165 }
166 }
167
168 pub fn into_tensor(self) -> Tensor {
173 let dt = self.datum_type();
174 let shape: TVec<usize> = self.shape().into();
175 Tensor::from_storage(dt, &shape, self)
176 }
177
178 pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
181 get_context()?.synchronize()?;
182
183 Ok(match self {
184 Self::Owned(o) => o.to_host()?,
185 Self::ArenaView(v) => v.to_host()?.into(),
186 })
187 }
188}
189
190impl Display for DeviceTensor {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 Self::Owned(o) => o.fmt(f),
194 Self::ArenaView(v) => {
195 let content =
196 v.to_host().unwrap().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
197 write!(f, "ArenaView: {{ {content} }}")
198 }
199 }
200 }
201}
202
203pub trait IntoDevice<T> {
204 fn into_device(self) -> TractResult<T>;
205}
206
207impl IntoDevice<DeviceTensor> for Tensor {
208 fn into_device(self) -> TractResult<DeviceTensor> {
209 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
210 }
211}
212
213impl IntoDevice<DeviceTensor> for Arc<Tensor> {
214 fn into_device(self) -> TractResult<DeviceTensor> {
215 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
216 }
217}
218
219impl TensorStorage for DeviceTensor {
220 fn byte_len(&self) -> usize {
221 self.len() * self.datum_type().size_of()
222 }
223
224 fn is_empty(&self) -> bool {
225 self.byte_len() == 0
226 }
227
228 fn deep_clone(&self) -> Box<dyn TensorStorage> {
229 Box::new(self.clone())
230 }
231
232 fn as_plain(&self) -> Option<&PlainStorage> {
233 None
234 }
235
236 fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
237 None
238 }
239
240 fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
241 None
242 }
243
244 fn dyn_hash(&self, _state: &mut dyn std::hash::Hasher) {
245 }
247
248 fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
249 bail!(
250 "DeviceTensor cannot reconstruct a DeviceFact: origin (FromHost/FromDevice) is not carried by storage"
251 )
252 }
253}
254
255impl From<DeviceArenaView> for DeviceTensor {
256 fn from(view: DeviceArenaView) -> Self {
257 Self::ArenaView(view)
258 }
259}
260
261pub trait DeviceTensorExt {
262 fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
263 fn as_device_tensor(&self) -> Option<&DeviceTensor>;
264}
265
266impl DeviceTensorExt for Tensor {
267 fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
268 self.try_storage_as::<DeviceTensor>()
269 }
270
271 fn as_device_tensor(&self) -> Option<&DeviceTensor> {
272 self.storage_as::<DeviceTensor>()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a single `f32` through a device tensor and back to the host.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let host = device.to_host()?;
        let plain = host.try_as_plain()?;
        assert_eq!(plain.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}