1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{DeviceBuffer, get_context};
17
18#[derive(Debug, Clone, Hash)]
21pub enum DeviceTensor {
22 Owned(Box<dyn OwnedDeviceTensor>),
23 ArenaView(DeviceArenaView),
24}
25
26impl DeviceTensor {
27 pub const SUPPORTED_DT: [DatumType; 12] = [
28 DatumType::Bool,
29 DatumType::F32,
30 DatumType::F16,
31 DatumType::I8,
32 DatumType::U8,
33 DatumType::I16,
34 DatumType::U16,
35 DatumType::I32,
36 DatumType::U32,
37 DatumType::I64,
38 DatumType::U64,
39 DatumType::Opaque,
40 ];
41
42 pub fn tname(dt: DatumType) -> TractResult<&'static str> {
43 Ok(match dt {
44 DatumType::F32 => "f32",
45 DatumType::F16 => "f16",
46 DatumType::U8 => "u8",
47 DatumType::U16 => "u16",
48 DatumType::U32 => "u32",
49 DatumType::U64 => "u64",
50 DatumType::I8 => "i8",
51 DatumType::I16 => "i16",
52 DatumType::I32 => "i32",
53 DatumType::I64 => "i64",
54 DatumType::Bool => "bool",
55 DatumType::Opaque => "opaque",
56 _ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
57 })
58 }
59
60 pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
62 Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
63 }
64
65 pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
66 Self::uninitialized_dt(T::datum_type(), shape)
67 }
68
69 pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
71 Tensor::from_shape(shape, data)?.into_device()
72 }
73
74 pub fn is_supported_dt(dt: DatumType) -> bool {
75 Self::SUPPORTED_DT.contains(&dt)
76 }
77
78 #[inline]
80 pub fn datum_type(&self) -> DatumType {
81 match self {
82 Self::Owned(owned) => owned.datum_type(),
83 Self::ArenaView(view) => view.datum_type(),
84 }
85 }
86
87 #[inline]
89 pub fn rank(&self) -> usize {
90 self.shape().len()
91 }
92
93 #[inline]
95 pub fn shape(&self) -> &[usize] {
96 match self {
97 Self::Owned(t) => t.shape(),
98 Self::ArenaView(t) => t.shape(),
99 }
100 }
101
102 #[inline]
104 #[allow(clippy::len_without_is_empty)]
105 pub fn len(&self) -> usize {
106 match self {
107 Self::Owned(t) => t.len(),
108 Self::ArenaView(t) => t.len(),
109 }
110 }
111
112 #[inline]
114 pub fn strides(&self) -> &[isize] {
115 match self {
116 Self::Owned(t) => t.strides(),
117 Self::ArenaView(t) => t.strides(),
118 }
119 }
120
121 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
123 match self {
124 Self::Owned(t) => t.device_buffer(),
125 Self::ArenaView(t) => t.device_buffer(),
126 }
127 }
128
129 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
131 where
132 usize: AsPrimitive<I>,
133 {
134 match self {
135 Self::Owned(_) => 0.as_(),
136 Self::ArenaView(t) => t.buffer_offset(),
137 }
138 }
139
140 pub fn device_buffer_ptr(&self) -> *const c_void {
141 match self {
142 Self::Owned(t) => t.device_buffer().ptr(),
143 Self::ArenaView(t) => t.device_buffer().ptr(),
144 }
145 }
146
147 #[inline]
149 pub fn view(&self) -> TensorView<'_> {
150 match self {
151 Self::Owned(t) => t.view(),
152 Self::ArenaView(t) => t.view(),
153 }
154 }
155
156 pub fn description(&self) -> String {
158 format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
159 }
160
161 pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
163 match self {
164 Self::Owned(t) => Ok(t.reshaped(shape)?),
165 Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
166 }
167 }
168
169 pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
170 match self {
171 Self::Owned(t) => Ok(t.restrided(strides)?),
172 Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
173 }
174 }
175
176 pub fn into_opaque_tensor(self) -> Tensor {
178 tensor0::<Opaque>(self.into())
179 }
180
181 pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
184 get_context()?.synchronize()?;
185
186 Ok(match self {
187 Self::Owned(o) => o.to_host()?,
188 Self::ArenaView(v) => v.clone().into_tensor().into(),
189 })
190 }
191}
192
193impl Display for DeviceTensor {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 match self {
196 Self::Owned(o) => o.fmt(f),
197 Self::ArenaView(v) => {
198 let content = v
199 .clone()
200 .into_tensor()
201 .dump(false)
202 .unwrap_or_else(|e| format!("Error : {e:?}"));
203 write!(f, "ArenaView: {{ {content} }}")
204 }
205 }
206 }
207}
208
209pub trait IntoDevice<T> {
210 fn into_device(self) -> TractResult<T>;
211}
212
213impl IntoDevice<DeviceTensor> for Tensor {
214 fn into_device(self) -> TractResult<DeviceTensor> {
215 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
216 }
217}
218
219impl IntoDevice<DeviceTensor> for Arc<Tensor> {
220 fn into_device(self) -> TractResult<DeviceTensor> {
221 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
222 }
223}
224
225impl From<DeviceTensor> for Opaque {
226 fn from(value: DeviceTensor) -> Self {
227 Opaque(Arc::new(value))
228 }
229}
230
231impl From<DeviceArenaView> for DeviceTensor {
232 fn from(view: DeviceArenaView) -> Self {
233 Self::ArenaView(view)
234 }
235}
236
237impl OpaquePayload for DeviceTensor {
238 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
239 other
240 .downcast_ref::<Self>()
241 .is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
242 }
243
244 fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
245 Ok(Some(self.to_host()?))
246 }
247}
248
249pub trait DeviceTensorExt {
250 fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
251 fn as_device_tensor(&self) -> Option<&DeviceTensor>;
252 fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
253 fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
254}
255
256impl DeviceTensorExt for Tensor {
257 fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
258 let opaque = self.to_scalar_mut::<Opaque>()?;
259 opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
260 anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
261 })
262 }
263
264 fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
265 let opaque = self.to_scalar_mut::<Opaque>().ok()?;
266 opaque.downcast_mut::<DeviceTensor>()
267 }
268
269 fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
270 let opaque = self.to_scalar::<Opaque>()?;
271 opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
272 anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
273 })
274 }
275
276 fn as_device_tensor(&self) -> Option<&DeviceTensor> {
277 let opaque = self.to_scalar::<Opaque>().ok()?;
278 opaque.downcast_ref::<DeviceTensor>()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a one-element f32 tensor through the device and back.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let on_device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let on_host = on_device.to_host()?;
        assert_eq!(on_host.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}