Skip to main content

tract_gpu/tensor/
mod.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{DeviceBuffer, get_context};
17
/// A GPU tensor that is either an owned tensor or an arena view.
///
/// Both variants expose the same accessors (datum type, shape, strides,
/// underlying device buffer) through the inherent methods of this enum.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum DeviceTensor {
    /// Tensor held through a boxed [`OwnedDeviceTensor`] trait object.
    Owned(Box<dyn OwnedDeviceTensor>),
    /// Tensor described by a [`DeviceArenaView`].
    ArenaView(DeviceArenaView),
}
25
26impl DeviceTensor {
27    pub const SUPPORTED_DT: [DatumType; 11] = [
28        DatumType::Bool,
29        DatumType::F32,
30        DatumType::F16,
31        DatumType::I8,
32        DatumType::U8,
33        DatumType::I16,
34        DatumType::U16,
35        DatumType::I32,
36        DatumType::U32,
37        DatumType::I64,
38        DatumType::U64,
39    ];
40
41    pub fn tname(dt: DatumType) -> TractResult<&'static str> {
42        Ok(match dt {
43            DatumType::F32 => "f32",
44            DatumType::F16 => "f16",
45            DatumType::U8 => "u8",
46            DatumType::U16 => "u16",
47            DatumType::U32 => "u32",
48            DatumType::U64 => "u64",
49            DatumType::I8 => "i8",
50            DatumType::I16 => "i16",
51            DatumType::I32 => "i32",
52            DatumType::I64 => "i64",
53            DatumType::Bool => "bool",
54            _ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
55        })
56    }
57
58    /// Create an uninitialized DeviceTensor
59    pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
60        Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
61    }
62
63    pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
64        Self::uninitialized_dt(T::datum_type(), shape)
65    }
66
67    pub fn uninitialized_exotic(exotic_fact: Box<dyn ExoticFact>) -> TractResult<DeviceTensor> {
68        Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_exotic_tensor(exotic_fact)?))
69    }
70    // Create a device tensor with a given shape and a slice of elements. The data is copied and aligned to size of T.
71    pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
72        Tensor::from_shape(shape, data)?.into_device()
73    }
74
75    pub fn is_supported_dt(dt: DatumType) -> bool {
76        Self::SUPPORTED_DT.contains(&dt)
77    }
78
79    /// Get the datum type of the tensor.
80    #[inline]
81    pub fn datum_type(&self) -> DatumType {
82        match self {
83            Self::Owned(owned) => owned.datum_type(),
84            Self::ArenaView(view) => view.datum_type(),
85        }
86    }
87
88    /// Get the number of dimensions (or axes) of the tensor.
89    #[inline]
90    pub fn rank(&self) -> usize {
91        self.shape().len()
92    }
93
94    /// Get the shape of the tensor.
95    #[inline]
96    pub fn shape(&self) -> &[usize] {
97        match self {
98            Self::Owned(t) => t.shape(),
99            Self::ArenaView(t) => t.shape(),
100        }
101    }
102
103    /// Get the number of values in the tensor.
104    #[inline]
105    #[allow(clippy::len_without_is_empty)]
106    pub fn len(&self) -> usize {
107        match self {
108            Self::Owned(t) => t.len(),
109            Self::ArenaView(t) => t.len(),
110        }
111    }
112
113    /// Get the strides of the tensor.
114    #[inline]
115    pub fn strides(&self) -> &[isize] {
116        match self {
117            Self::Owned(t) => t.strides(),
118            Self::ArenaView(t) => t.strides(),
119        }
120    }
121
122    /// Get underlying inner buffer.
123    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
124        match self {
125            Self::Owned(t) => t.device_buffer(),
126            Self::ArenaView(t) => t.device_buffer(),
127        }
128    }
129
130    /// Get underlying inner buffer offset
131    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
132    where
133        usize: AsPrimitive<I>,
134    {
135        match self {
136            Self::Owned(_) => 0.as_(),
137            Self::ArenaView(t) => t.buffer_offset(),
138        }
139    }
140
141    pub fn device_buffer_ptr(&self) -> *const c_void {
142        match self {
143            Self::Owned(t) => t.device_buffer().ptr(),
144            Self::ArenaView(t) => t.device_buffer().ptr(),
145        }
146    }
147
148    /// Returns short description of the inner tensor.
149    pub fn description(&self) -> String {
150        format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
151    }
152
153    /// Reshaped tensor with given shape.
154    pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
155        match self {
156            Self::Owned(t) => Ok(t.reshaped(shape)?),
157            Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
158        }
159    }
160
161    pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
162        match self {
163            Self::Owned(t) => Ok(t.restrided(strides)?),
164            Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
165        }
166    }
167
168    /// Convert device tensor to a Tensor backed by device storage.
169    ///
170    /// The resulting tensor carries the real datum type and shape from the
171    /// device tensor (e.g. F32 / \[2,3\]), rather than an exotic scalar wrapper.
172    pub fn into_tensor(self) -> Tensor {
173        let dt = self.datum_type();
174        let shape: TVec<usize> = self.shape().into();
175        Tensor::from_storage(dt, &shape, self)
176    }
177
178    /// Synchronize the GPU Tensor by completing all current
179    /// commands on GPU and returns the inner tensor.
180    pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
181        get_context()?.synchronize()?;
182
183        Ok(match self {
184            Self::Owned(o) => o.to_host()?,
185            Self::ArenaView(v) => v.to_host()?.into(),
186        })
187    }
188}
189
190impl Display for DeviceTensor {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            Self::Owned(o) => o.fmt(f),
194            Self::ArenaView(v) => {
195                let content =
196                    v.to_host().unwrap().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
197                write!(f, "ArenaView: {{ {content} }}")
198            }
199        }
200    }
201}
202
/// Conversion of a value into a device-side representation `T`.
///
/// In this file it is implemented for `Tensor` and `Arc<Tensor>`, both
/// producing a `DeviceTensor`.
pub trait IntoDevice<T> {
    /// Consumes `self` and returns its device counterpart, or an error if the
    /// conversion fails.
    fn into_device(self) -> TractResult<T>;
}
206
207impl IntoDevice<DeviceTensor> for Tensor {
208    fn into_device(self) -> TractResult<DeviceTensor> {
209        Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
210    }
211}
212
213impl IntoDevice<DeviceTensor> for Arc<Tensor> {
214    fn into_device(self) -> TractResult<DeviceTensor> {
215        Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
216    }
217}
218
219impl TensorStorage for DeviceTensor {
220    fn byte_len(&self) -> usize {
221        self.len() * self.datum_type().size_of()
222    }
223
224    fn is_empty(&self) -> bool {
225        self.byte_len() == 0
226    }
227
228    fn deep_clone(&self) -> Box<dyn TensorStorage> {
229        Box::new(self.clone())
230    }
231
232    fn as_plain(&self) -> Option<&PlainStorage> {
233        None
234    }
235
236    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
237        None
238    }
239
240    fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
241        None
242    }
243
244    fn dyn_hash(&self, _state: &mut dyn std::hash::Hasher) {
245        // no meaningful hash for device memory
246    }
247
248    fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
249        bail!(
250            "DeviceTensor cannot reconstruct a DeviceFact: origin (FromHost/FromDevice) is not carried by storage"
251        )
252    }
253}
254
255impl From<DeviceArenaView> for DeviceTensor {
256    fn from(view: DeviceArenaView) -> Self {
257        Self::ArenaView(view)
258    }
259}
260
/// Accessors for retrieving a borrowed `DeviceTensor` from a value.
pub trait DeviceTensorExt {
    /// Fallible accessor: returns the device tensor or an error.
    fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
    /// Optional accessor: returns the device tensor, or `None` when there is none.
    fn as_device_tensor(&self) -> Option<&DeviceTensor>;
}
265
266impl DeviceTensorExt for Tensor {
267    fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
268        self.try_storage_as::<DeviceTensor>()
269    }
270
271    fn as_device_tensor(&self) -> Option<&DeviceTensor> {
272        self.storage_as::<DeviceTensor>()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a one-element f32 tensor through the device and checks the
    /// value read back on the host.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let host = device.to_host()?;
        let plain = host.try_as_plain()?;
        let values = plain.as_slice::<f32>()?;
        assert_eq!(values, &[0.0]);
        Ok(())
    }
}