Skip to main content

tract_gpu/tensor/
owned.rs

1use crate::device::{get_context, DeviceBuffer};
2use crate::tensor::DeviceTensor;
3use crate::utils::{as_q40_tensor, check_strides_validity};
4use num_traits::AsPrimitive;
5use std::ffi::c_void;
6use std::fmt::Display;
7use tract_core::internal::*;
8
9#[derive(Debug, Clone, Hash)]
10pub enum DValue {
11    Natural(Arc<Tensor>),
12    Reshaped { t: Arc<Tensor>, shape: TVec<usize>, strides: TVec<isize> },
13}
14
15impl DValue {
16    #[inline]
17    pub fn view(&self) -> TensorView<'_> {
18        match self {
19            Self::Natural(t) => t.view(),
20            Self::Reshaped { t, shape, strides } => unsafe {
21                TensorView::from_bytes(t, 0, shape.as_slice(), strides.as_slice())
22            },
23        }
24    }
25
26    /// Get the datum type of the tensor.
27    #[inline]
28    pub fn datum_type(&self) -> DatumType {
29        match self {
30            Self::Natural(t) => t.datum_type(),
31            Self::Reshaped { t, .. } => t.datum_type(),
32        }
33    }
34
35    #[inline]
36    pub fn shape(&self) -> &[usize] {
37        match self {
38            DValue::Natural(t) => t.shape(),
39            DValue::Reshaped { shape, .. } => shape,
40        }
41    }
42
43    /// Get the number of values.
44    #[inline]
45    #[allow(clippy::len_without_is_empty)]
46    pub fn len(&self) -> usize {
47        self.shape().iter().product()
48    }
49
50    /// Reshaped tensor with given shape.
51    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
52        let shape = shape.into();
53        if self.len() != shape.iter().product::<usize>() {
54            bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
55        }
56        if shape.as_slice() != self.shape() {
57            match &self {
58                DValue::Natural(t) | DValue::Reshaped { t, .. } => Ok(Self::Reshaped {
59                    t: Arc::clone(t),
60                    strides: Tensor::natural_strides(&shape),
61                    shape,
62                }),
63            }
64        } else {
65            Ok(self.clone())
66        }
67    }
68
69    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
70        let strides = strides.into();
71        check_strides_validity(self.shape().into(), strides.clone())?;
72
73        match &self {
74            DValue::Natural(t) => {
75                Ok(Self::Reshaped { t: Arc::clone(t), strides, shape: self.shape().into() })
76            }
77            DValue::Reshaped { t, strides: old_strides, .. } => {
78                if &strides != old_strides {
79                    Ok(Self::Reshaped { t: Arc::clone(t), strides, shape: self.shape().into() })
80                } else {
81                    Ok(self.clone())
82                }
83            }
84        }
85    }
86
87    pub fn as_arc_tensor(&self) -> Option<&Arc<Tensor>> {
88        match self {
89            DValue::Natural(t) => Some(t),
90            DValue::Reshaped { .. } => None,
91        }
92    }
93
94    /// Reshaped tensor with given shape and strides, no consistency check.
95    pub unsafe fn reshaped_with_geometry_unchecked(
96        &self,
97        shape: impl Into<TVec<usize>>,
98        strides: impl Into<TVec<isize>>,
99    ) -> Self {
100        match self {
101            DValue::Natural(t) | DValue::Reshaped { t, .. } => {
102                DValue::Reshaped { t: Arc::clone(t), strides: strides.into(), shape: shape.into() }
103            }
104        }
105    }
106}
107
108impl IntoTensor for DValue {
109    fn into_tensor(self) -> Tensor {
110        match self {
111            Self::Natural(t) => Arc::try_unwrap(t).unwrap_or_else(|t| (*t).clone()),
112            Self::Reshaped { t, shape, strides: _ } => {
113                let mut t = Arc::try_unwrap(t).unwrap_or_else(|t| (*t).clone());
114                t.set_shape(&shape).expect("Could not apply shape to reshaped GPU tensor");
115                t
116            }
117        }
118    }
119}
120
121impl From<Tensor> for DValue {
122    fn from(v: Tensor) -> Self {
123        Self::Natural(Arc::new(v))
124    }
125}
126
127impl From<Arc<Tensor>> for DValue {
128    fn from(v: Arc<Tensor>) -> Self {
129        Self::Natural(v)
130    }
131}
132
133/// This struct represents a owned tensor that can be accessed from the
134/// GPU and the CPU.
135#[derive(Debug, Clone)]
136pub struct OwnedDeviceTensor {
137    pub inner: DValue,
138    pub device_buffer: Box<dyn DeviceBuffer>,
139}
140
141impl Hash for OwnedDeviceTensor {
142    #[inline]
143    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
144        self.inner.hash(state)
145    }
146}
147
148impl OwnedDeviceTensor {
149    /// Create a owned gpu tensor from a cpu tensor.
150    pub fn from_tensor<T: Into<DValue>>(tensor: T) -> TractResult<Self> {
151        let m_value: DValue = tensor.into();
152        let tensor_view = m_value.view();
153        ensure!(
154            DeviceTensor::is_supported_dt(tensor_view.datum_type()),
155            "Tensor of {:?} is not copied. No device buffer can be allocated for it.",
156            tensor_view.datum_type(),
157        );
158
159        let data_bytes = as_q40_tensor(tensor_view.tensor)
160            .map(|bqv| bqv.value.as_bytes())
161            .unwrap_or(tensor_view.tensor.as_bytes());
162
163        let device_buffer = get_context()?.buffer_from_slice(data_bytes);
164
165        Ok(OwnedDeviceTensor { inner: m_value, device_buffer })
166    }
167
168    #[inline]
169    pub fn shape(&self) -> &[usize] {
170        self.inner.shape()
171    }
172
173    /// Get the number of values in the tensor.
174    #[inline]
175    #[allow(clippy::len_without_is_empty)]
176    pub fn len(&self) -> usize {
177        self.shape().iter().product()
178    }
179
180    /// Get the strides of the tensor.
181    #[inline]
182    pub fn strides(&self) -> &[isize] {
183        match &self.inner {
184            DValue::Natural(t) => t.strides(),
185            DValue::Reshaped { strides, .. } => strides,
186        }
187    }
188
189    /// Get underlying inner device buffer.
190    #[inline]
191    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
192        &(*self.device_buffer)
193    }
194
195    pub fn device_buffer_ptr(&self) -> *const c_void {
196        self.device_buffer.ptr()
197    }
198
199    /// Get underlying inner buffer offset
200    #[inline]
201    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
202    where
203        usize: AsPrimitive<I>,
204    {
205        // No offset for non-arena tensor
206        0usize.as_()
207    }
208
209    /// Reshaped tensor with given shape.
210    #[inline]
211    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
212        Ok(Self { inner: self.inner.reshaped(shape)?, device_buffer: self.device_buffer.clone() })
213    }
214
215    /// Change tensor stride.
216    #[inline]
217    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
218        Ok(Self {
219            inner: self.inner.restrided(strides)?,
220            device_buffer: self.device_buffer.clone(),
221        })
222    }
223
224    /// Reshaped tensor with given shape and strides, no consistency check.
225    #[inline]
226    pub unsafe fn reshaped_with_geometry_unchecked(
227        &self,
228        shape: impl Into<TVec<usize>>,
229        strides: impl Into<TVec<isize>>,
230    ) -> Self {
231        Self {
232            inner: self.inner.reshaped_with_geometry_unchecked(shape, strides),
233            device_buffer: self.device_buffer.clone(),
234        }
235    }
236
237    #[inline]
238    pub fn view(&self) -> TensorView<'_> {
239        self.inner.view()
240    }
241}
242
243impl Display for OwnedDeviceTensor {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        match &self.inner {
246            DValue::Natural(t) => {
247                let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
248                write!(f, "GPU {{ {content} }}")
249            }
250            DValue::Reshaped { t, shape, strides: _ } => {
251                let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
252                write!(f, "GPU reshaped: {:?} - {{ {content} }}", shape)
253            }
254        }
255    }
256}