Skip to main content

tract_gpu/tensor/
mod.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{get_context, DeviceBuffer};
17
18/// This struct represents a GPU tensor that can be either a owned tensor
19/// or an arena view.
20#[derive(Debug, Clone, Hash)]
21pub enum DeviceTensor {
22    Owned(OwnedDeviceTensor),
23    ArenaView(DeviceArenaView),
24}
25
26impl DeviceTensor {
27    pub const SUPPORTED_DT: [DatumType; 12] = [
28        DatumType::Bool,
29        DatumType::F32,
30        DatumType::F16,
31        DatumType::I8,
32        DatumType::U8,
33        DatumType::I16,
34        DatumType::U16,
35        DatumType::I32,
36        DatumType::U32,
37        DatumType::I64,
38        DatumType::U64,
39        DatumType::Opaque,
40    ];
41
42    pub fn tname(dt: DatumType) -> TractResult<&'static str> {
43        Ok(match dt {
44            DatumType::F32 => "f32",
45            DatumType::F16 => "f16",
46            DatumType::U8 => "u8",
47            DatumType::U16 => "u16",
48            DatumType::U32 => "u32",
49            DatumType::U64 => "u64",
50            DatumType::I8 => "i8",
51            DatumType::I16 => "i16",
52            DatumType::I32 => "i32",
53            DatumType::I64 => "i64",
54            DatumType::Bool => "bool",
55            DatumType::Opaque => "opaque",
56            _ => bail!("Unsupport dt {:?} for GPU Tensor", dt),
57        })
58    }
59
60    /// Create an uninitialized DeviceTensor
61    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
62        Tensor::uninitialized_dt(dt, shape)?.into_device()
63    }
64
65    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
66        Self::uninitialized_dt(T::datum_type(), shape)
67    }
68
69    // Create a device tensor with a given shape and a slice of elements. The data is copied and aligned to size of T.
70    pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
71        Tensor::from_shape(shape, data)?.into_device()
72    }
73
74    pub fn is_supported_dt(dt: DatumType) -> bool {
75        Self::SUPPORTED_DT.contains(&dt)
76    }
77
78    /// Get the datum type of the tensor.
79    #[inline]
80    pub fn datum_type(&self) -> DatumType {
81        match self {
82            Self::Owned(OwnedDeviceTensor { inner, .. }) => inner.datum_type(),
83            Self::ArenaView(view) => view.datum_type(),
84        }
85    }
86
87    /// Get the number of dimensions (or axes) of the tensor.
88    #[inline]
89    pub fn rank(&self) -> usize {
90        self.shape().len()
91    }
92
93    /// Get the shape of the tensor.
94    #[inline]
95    pub fn shape(&self) -> &[usize] {
96        match self {
97            Self::Owned(t) => t.shape(),
98            Self::ArenaView(t) => t.shape(),
99        }
100    }
101
102    /// Get the number of values in the tensor.
103    #[inline]
104    #[allow(clippy::len_without_is_empty)]
105    pub fn len(&self) -> usize {
106        match self {
107            Self::Owned(t) => t.len(),
108            Self::ArenaView(t) => t.len(),
109        }
110    }
111
112    /// Get the strides of the tensor.
113    #[inline]
114    pub fn strides(&self) -> &[isize] {
115        match self {
116            Self::Owned(t) => t.strides(),
117            Self::ArenaView(t) => t.strides(),
118        }
119    }
120
121    /// Get underlying inner buffer.
122    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
123        match self {
124            Self::Owned(t) => t.device_buffer(),
125            Self::ArenaView(t) => t.device_buffer(),
126        }
127    }
128
129    /// Get underlying inner buffer offset
130    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
131    where
132        usize: AsPrimitive<I>,
133    {
134        match self {
135            Self::Owned(t) => t.buffer_offset(),
136            Self::ArenaView(t) => t.buffer_offset(),
137        }
138    }
139
140    pub fn device_buffer_ptr(&self) -> *const c_void {
141        match self {
142            Self::Owned(t) => t.device_buffer_ptr(),
143            Self::ArenaView(t) => t.device_buffer_ptr(),
144        }
145    }
146
147    /// Get underlying inner tensor view.
148    #[inline]
149    pub fn view(&self) -> TensorView {
150        match self {
151            Self::Owned(t) => t.view(),
152            Self::ArenaView(t) => t.view(),
153        }
154    }
155
156    /// Returns short description of the inner tensor.
157    pub fn description(&self) -> String {
158        format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
159    }
160
161    /// Reshaped tensor with given shape.
162    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
163        match self {
164            Self::Owned(t) => Ok(Self::Owned(t.reshaped(shape)?)),
165            Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
166        }
167    }
168
169    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
170        match self {
171            Self::Owned(t) => Ok(Self::Owned(t.restrided(strides)?)),
172            Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
173        }
174    }
175
176    /// Convert device tensor to Opaque Tensor.
177    pub fn into_opaque_tensor(self) -> Tensor {
178        tensor0::<Opaque>(self.into())
179    }
180
181    /// Synchronize the GPU Tensor by completing all current
182    /// commands on GPU and returns the inner tensor.
183    pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
184        get_context()?.synchronize()?;
185
186        Ok(match self {
187            Self::Owned(o) => o
188                .inner
189                .as_arc_tensor()
190                .cloned()
191                .unwrap_or_else(|| o.inner.clone().into_tensor().into_arc_tensor()),
192            Self::ArenaView(v) => v.clone().into_tensor().into(),
193        })
194    }
195}
196
197impl Display for DeviceTensor {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        match self {
200            Self::Owned(o) => match &o.inner {
201                DValue::Natural(t) => {
202                    let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
203                    write!(f, "Owned: {{ {content} }}")
204                }
205                DValue::Reshaped { t, shape, .. } => {
206                    let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
207                    write!(f, "Owned,Reshaped: {:?} - {{ {content} }}", shape)
208                }
209            },
210            Self::ArenaView(v) => {
211                let content = v
212                    .clone()
213                    .into_tensor()
214                    .dump(false)
215                    .unwrap_or_else(|e| format!("Error : {e:?}"));
216                write!(f, "ArenaView: {{ {content} }}")
217            }
218        }
219    }
220}
221
222pub trait IntoDevice<T> {
223    fn into_device(self) -> TractResult<T>;
224}
225
226impl IntoDevice<DeviceTensor> for Tensor {
227    fn into_device(self) -> TractResult<DeviceTensor> {
228        Ok(DeviceTensor::Owned(OwnedDeviceTensor::from_tensor(self)?))
229    }
230}
231
232impl IntoDevice<DeviceTensor> for Arc<Tensor> {
233    fn into_device(self) -> TractResult<DeviceTensor> {
234        Ok(DeviceTensor::Owned(OwnedDeviceTensor::from_tensor(self)?))
235    }
236}
237
238impl From<DeviceTensor> for Opaque {
239    fn from(value: DeviceTensor) -> Self {
240        Opaque(Arc::new(value))
241    }
242}
243
244impl From<DeviceArenaView> for DeviceTensor {
245    fn from(view: DeviceArenaView) -> Self {
246        Self::ArenaView(view)
247    }
248}
249
250impl OpaquePayload for DeviceTensor {
251    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
252        other
253            .downcast_ref::<Self>()
254            .is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
255    }
256
257    fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
258        Ok(Some(self.to_host()?))
259    }
260}
261
262pub trait DeviceTensorExt {
263    fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
264    fn as_device_tensor(&self) -> Option<&DeviceTensor>;
265    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
266    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
267}
268
269impl DeviceTensorExt for Tensor {
270    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
271        let opaque = self.to_scalar_mut::<Opaque>()?;
272        opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
273            anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
274        })
275    }
276
277    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
278        let opaque = self.to_scalar_mut::<Opaque>().ok()?;
279        opaque.downcast_mut::<DeviceTensor>()
280    }
281
282    fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
283        let opaque = self.to_scalar::<Opaque>()?;
284        opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
285            anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
286        })
287    }
288
289    fn as_device_tensor(&self) -> Option<&DeviceTensor> {
290        let opaque = self.to_scalar::<Opaque>().ok()?;
291        opaque.downcast_ref::<DeviceTensor>()
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a one-element `f32` tensor to the device and reads it back on the host.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let host = device.to_host()?;
        assert_eq!(host.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}