1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{get_context, DeviceBuffer};
17
/// A tensor whose data lives in a device buffer.
///
/// A value either owns its storage ([`OwnedDeviceTensor`]) or is a view into a device arena
/// ([`DeviceArenaView`]). Most accessors on this type simply delegate to the active variant.
#[derive(Debug, Clone, Hash)]
pub enum DeviceTensor {
    /// Tensor backed by its own storage.
    Owned(OwnedDeviceTensor),
    /// Tensor that is a view into a device arena.
    ArenaView(DeviceArenaView),
}
25
26impl DeviceTensor {
27 pub const SUPPORTED_DT: [DatumType; 12] = [
28 DatumType::Bool,
29 DatumType::F32,
30 DatumType::F16,
31 DatumType::I8,
32 DatumType::U8,
33 DatumType::I16,
34 DatumType::U16,
35 DatumType::I32,
36 DatumType::U32,
37 DatumType::I64,
38 DatumType::U64,
39 DatumType::Opaque,
40 ];
41
42 pub fn tname(dt: DatumType) -> TractResult<&'static str> {
43 Ok(match dt {
44 DatumType::F32 => "f32",
45 DatumType::F16 => "f16",
46 DatumType::U8 => "u8",
47 DatumType::U16 => "u16",
48 DatumType::U32 => "u32",
49 DatumType::U64 => "u64",
50 DatumType::I8 => "i8",
51 DatumType::I16 => "i16",
52 DatumType::I32 => "i32",
53 DatumType::I64 => "i64",
54 DatumType::Bool => "bool",
55 DatumType::Opaque => "opaque",
56 _ => bail!("Unsupport dt {:?} for GPU Tensor", dt),
57 })
58 }
59
60 pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
62 Tensor::uninitialized_dt(dt, shape)?.into_device()
63 }
64
65 pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
66 Self::uninitialized_dt(T::datum_type(), shape)
67 }
68
69 pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
71 Tensor::from_shape(shape, data)?.into_device()
72 }
73
74 pub fn is_supported_dt(dt: DatumType) -> bool {
75 Self::SUPPORTED_DT.contains(&dt)
76 }
77
78 #[inline]
80 pub fn datum_type(&self) -> DatumType {
81 match self {
82 Self::Owned(OwnedDeviceTensor { inner, .. }) => inner.datum_type(),
83 Self::ArenaView(view) => view.datum_type(),
84 }
85 }
86
87 #[inline]
89 pub fn rank(&self) -> usize {
90 self.shape().len()
91 }
92
93 #[inline]
95 pub fn shape(&self) -> &[usize] {
96 match self {
97 Self::Owned(t) => t.shape(),
98 Self::ArenaView(t) => t.shape(),
99 }
100 }
101
102 #[inline]
104 #[allow(clippy::len_without_is_empty)]
105 pub fn len(&self) -> usize {
106 match self {
107 Self::Owned(t) => t.len(),
108 Self::ArenaView(t) => t.len(),
109 }
110 }
111
112 #[inline]
114 pub fn strides(&self) -> &[isize] {
115 match self {
116 Self::Owned(t) => t.strides(),
117 Self::ArenaView(t) => t.strides(),
118 }
119 }
120
121 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
123 match self {
124 Self::Owned(t) => t.device_buffer(),
125 Self::ArenaView(t) => t.device_buffer(),
126 }
127 }
128
129 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
131 where
132 usize: AsPrimitive<I>,
133 {
134 match self {
135 Self::Owned(t) => t.buffer_offset(),
136 Self::ArenaView(t) => t.buffer_offset(),
137 }
138 }
139
140 pub fn device_buffer_ptr(&self) -> *const c_void {
141 match self {
142 Self::Owned(t) => t.device_buffer_ptr(),
143 Self::ArenaView(t) => t.device_buffer_ptr(),
144 }
145 }
146
147 #[inline]
149 pub fn view(&self) -> TensorView {
150 match self {
151 Self::Owned(t) => t.view(),
152 Self::ArenaView(t) => t.view(),
153 }
154 }
155
156 pub fn description(&self) -> String {
158 format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
159 }
160
161 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
163 match self {
164 Self::Owned(t) => Ok(Self::Owned(t.reshaped(shape)?)),
165 Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
166 }
167 }
168
169 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
170 match self {
171 Self::Owned(t) => Ok(Self::Owned(t.restrided(strides)?)),
172 Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
173 }
174 }
175
176 pub fn into_opaque_tensor(self) -> Tensor {
178 tensor0::<Opaque>(self.into())
179 }
180
181 pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
184 get_context()?.synchronize()?;
185
186 Ok(match self {
187 Self::Owned(o) => o
188 .inner
189 .as_arc_tensor()
190 .cloned()
191 .unwrap_or_else(|| o.inner.clone().into_tensor().into_arc_tensor()),
192 Self::ArenaView(v) => v.clone().into_tensor().into(),
193 })
194 }
195}
196
197impl Display for DeviceTensor {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 match self {
200 Self::Owned(o) => match &o.inner {
201 DValue::Natural(t) => {
202 let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
203 write!(f, "Owned: {{ {content} }}")
204 }
205 DValue::Reshaped { t, shape, .. } => {
206 let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
207 write!(f, "Owned,Reshaped: {:?} - {{ {content} }}", shape)
208 }
209 },
210 Self::ArenaView(v) => {
211 let content = v
212 .clone()
213 .into_tensor()
214 .dump(false)
215 .unwrap_or_else(|e| format!("Error : {e:?}"));
216 write!(f, "ArenaView: {{ {content} }}")
217 }
218 }
219 }
220}
221
/// Conversion of a host value into a device-side value of type `T`.
pub trait IntoDevice<T> {
    /// Consumes `self` and returns its device counterpart.
    ///
    /// Implementations report conversion failures through the returned `TractResult`.
    fn into_device(self) -> TractResult<T>;
}
225
226impl IntoDevice<DeviceTensor> for Tensor {
227 fn into_device(self) -> TractResult<DeviceTensor> {
228 Ok(DeviceTensor::Owned(OwnedDeviceTensor::from_tensor(self)?))
229 }
230}
231
232impl IntoDevice<DeviceTensor> for Arc<Tensor> {
233 fn into_device(self) -> TractResult<DeviceTensor> {
234 Ok(DeviceTensor::Owned(OwnedDeviceTensor::from_tensor(self)?))
235 }
236}
237
238impl From<DeviceTensor> for Opaque {
239 fn from(value: DeviceTensor) -> Self {
240 Opaque(Arc::new(value))
241 }
242}
243
244impl From<DeviceArenaView> for DeviceTensor {
245 fn from(view: DeviceArenaView) -> Self {
246 Self::ArenaView(view)
247 }
248}
249
250impl OpaquePayload for DeviceTensor {
251 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
252 other
253 .downcast_ref::<Self>()
254 .is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
255 }
256
257 fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
258 Ok(Some(self.to_host()?))
259 }
260}
261
/// Accessors for a [`DeviceTensor`] wrapped inside another value (for [`Tensor`], as an
/// opaque scalar payload).
pub trait DeviceTensorExt {
    /// Borrows the wrapped device tensor, or returns an error if `self` does not hold one.
    fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
    /// Borrows the wrapped device tensor, or returns `None` if `self` does not hold one.
    fn as_device_tensor(&self) -> Option<&DeviceTensor>;
    /// Mutable counterpart of [`DeviceTensorExt::to_device_tensor`].
    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
    /// Mutable counterpart of [`DeviceTensorExt::as_device_tensor`].
    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
}
268
269impl DeviceTensorExt for Tensor {
270 fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
271 let opaque = self.to_scalar_mut::<Opaque>()?;
272 opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
273 anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
274 })
275 }
276
277 fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
278 let opaque = self.to_scalar_mut::<Opaque>().ok()?;
279 opaque.downcast_mut::<DeviceTensor>()
280 }
281
282 fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
283 let opaque = self.to_scalar::<Opaque>()?;
284 opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
285 anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
286 })
287 }
288
289 fn as_device_tensor(&self) -> Option<&DeviceTensor> {
290 let opaque = self.to_scalar::<Opaque>().ok()?;
291 opaque.downcast_ref::<DeviceTensor>()
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a one-element f32 tensor through the device and back to the host.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let on_device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let on_host = on_device.to_host()?;
        assert_eq!(on_host.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}