Skip to main content

tract_gpu/tensor/
mod.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{DeviceBuffer, get_context};
17
18/// This struct represents a GPU tensor that can be either a owned tensor
19/// or an arena view.
20#[derive(Debug, Clone, Hash)]
21pub enum DeviceTensor {
22    Owned(Box<dyn OwnedDeviceTensor>),
23    ArenaView(DeviceArenaView),
24}
25
26impl DeviceTensor {
27    pub const SUPPORTED_DT: [DatumType; 12] = [
28        DatumType::Bool,
29        DatumType::F32,
30        DatumType::F16,
31        DatumType::I8,
32        DatumType::U8,
33        DatumType::I16,
34        DatumType::U16,
35        DatumType::I32,
36        DatumType::U32,
37        DatumType::I64,
38        DatumType::U64,
39        DatumType::Opaque,
40    ];
41
42    pub fn tname(dt: DatumType) -> TractResult<&'static str> {
43        Ok(match dt {
44            DatumType::F32 => "f32",
45            DatumType::F16 => "f16",
46            DatumType::U8 => "u8",
47            DatumType::U16 => "u16",
48            DatumType::U32 => "u32",
49            DatumType::U64 => "u64",
50            DatumType::I8 => "i8",
51            DatumType::I16 => "i16",
52            DatumType::I32 => "i32",
53            DatumType::I64 => "i64",
54            DatumType::Bool => "bool",
55            DatumType::Opaque => "opaque",
56            _ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
57        })
58    }
59
60    /// Create an uninitialized DeviceTensor
61    pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
62        Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
63    }
64
65    pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
66        Self::uninitialized_dt(T::datum_type(), shape)
67    }
68
69    pub fn uninitialized_opaque(opaque_fact: Box<dyn OpaqueFact>) -> TractResult<DeviceTensor> {
70        Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_opaque_tensor(opaque_fact)?))
71    }
72    // Create a device tensor with a given shape and a slice of elements. The data is copied and aligned to size of T.
73    pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
74        Tensor::from_shape(shape, data)?.into_device()
75    }
76
77    pub fn is_supported_dt(dt: DatumType) -> bool {
78        Self::SUPPORTED_DT.contains(&dt)
79    }
80
81    /// Get the datum type of the tensor.
82    #[inline]
83    pub fn datum_type(&self) -> DatumType {
84        match self {
85            Self::Owned(owned) => owned.datum_type(),
86            Self::ArenaView(view) => view.datum_type(),
87        }
88    }
89
90    /// Get the number of dimensions (or axes) of the tensor.
91    #[inline]
92    pub fn rank(&self) -> usize {
93        self.shape().len()
94    }
95
96    /// Get the shape of the tensor.
97    #[inline]
98    pub fn shape(&self) -> &[usize] {
99        match self {
100            Self::Owned(t) => t.shape(),
101            Self::ArenaView(t) => t.shape(),
102        }
103    }
104
105    /// Get the number of values in the tensor.
106    #[inline]
107    #[allow(clippy::len_without_is_empty)]
108    pub fn len(&self) -> usize {
109        match self {
110            Self::Owned(t) => t.len(),
111            Self::ArenaView(t) => t.len(),
112        }
113    }
114
115    /// Get the strides of the tensor.
116    #[inline]
117    pub fn strides(&self) -> &[isize] {
118        match self {
119            Self::Owned(t) => t.strides(),
120            Self::ArenaView(t) => t.strides(),
121        }
122    }
123
124    /// Get underlying inner buffer.
125    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
126        match self {
127            Self::Owned(t) => t.device_buffer(),
128            Self::ArenaView(t) => t.device_buffer(),
129        }
130    }
131
132    /// Get underlying inner buffer offset
133    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
134    where
135        usize: AsPrimitive<I>,
136    {
137        match self {
138            Self::Owned(_) => 0.as_(),
139            Self::ArenaView(t) => t.buffer_offset(),
140        }
141    }
142
143    pub fn device_buffer_ptr(&self) -> *const c_void {
144        match self {
145            Self::Owned(t) => t.device_buffer().ptr(),
146            Self::ArenaView(t) => t.device_buffer().ptr(),
147        }
148    }
149
150    /// Returns short description of the inner tensor.
151    pub fn description(&self) -> String {
152        format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
153    }
154
155    /// Reshaped tensor with given shape.
156    pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
157        match self {
158            Self::Owned(t) => Ok(t.reshaped(shape)?),
159            Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
160        }
161    }
162
163    pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
164        match self {
165            Self::Owned(t) => Ok(t.restrided(strides)?),
166            Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
167        }
168    }
169
170    /// Convert device tensor to Opaque Tensor.
171    pub fn into_opaque_tensor(self) -> Tensor {
172        tensor0::<Opaque>(self.into())
173    }
174
175    /// Synchronize the GPU Tensor by completing all current
176    /// commands on GPU and returns the inner tensor.
177    pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
178        get_context()?.synchronize()?;
179
180        Ok(match self {
181            Self::Owned(o) => o.to_host()?,
182            Self::ArenaView(v) => v.to_host()?.into(),
183        })
184    }
185}
186
187impl Display for DeviceTensor {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            Self::Owned(o) => o.fmt(f),
191            Self::ArenaView(v) => {
192                let content =
193                    v.to_host().unwrap().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
194                write!(f, "ArenaView: {{ {content} }}")
195            }
196        }
197    }
198}
199
200pub trait IntoDevice<T> {
201    fn into_device(self) -> TractResult<T>;
202}
203
204impl IntoDevice<DeviceTensor> for Tensor {
205    fn into_device(self) -> TractResult<DeviceTensor> {
206        Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
207    }
208}
209
210impl IntoDevice<DeviceTensor> for Arc<Tensor> {
211    fn into_device(self) -> TractResult<DeviceTensor> {
212        Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
213    }
214}
215
216impl From<DeviceTensor> for Opaque {
217    fn from(value: DeviceTensor) -> Self {
218        Opaque(Arc::new(value))
219    }
220}
221
222impl From<DeviceArenaView> for DeviceTensor {
223    fn from(view: DeviceArenaView) -> Self {
224        Self::ArenaView(view)
225    }
226}
227
228impl OpaquePayload for DeviceTensor {
229    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
230        other
231            .downcast_ref::<Self>()
232            .is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
233    }
234
235    fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
236        Ok(Some(self.to_host()?))
237    }
238}
239
240pub trait DeviceTensorExt {
241    fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
242    fn as_device_tensor(&self) -> Option<&DeviceTensor>;
243    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
244    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
245}
246
247impl DeviceTensorExt for Tensor {
248    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
249        let opaque = self.to_scalar_mut::<Opaque>()?;
250        opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
251            anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
252        })
253    }
254
255    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
256        let opaque = self.to_scalar_mut::<Opaque>().ok()?;
257        opaque.downcast_mut::<DeviceTensor>()
258    }
259
260    fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
261        let opaque = self.to_scalar::<Opaque>()?;
262        opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
263            anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
264        })
265    }
266
267    fn as_device_tensor(&self) -> Option<&DeviceTensor> {
268        let opaque = self.to_scalar::<Opaque>().ok()?;
269        opaque.downcast_ref::<DeviceTensor>()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a one-element f32 tensor through the device and checks the value.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let on_device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let back_on_host = on_device.to_host()?;
        assert_eq!(back_on_host.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}