tract_gpu/tensor/
owned.rs1use crate::device::{get_context, DeviceBuffer};
2use crate::tensor::DeviceTensor;
3use crate::utils::{as_q40_tensor, check_strides_validity};
4use num_traits::AsPrimitive;
5use std::ffi::c_void;
6use std::fmt::Display;
7use tract_core::internal::*;
8
9#[derive(Debug, Clone, Hash)]
10pub enum DValue {
11 Natural(Arc<Tensor>),
12 Reshaped { t: Arc<Tensor>, shape: TVec<usize>, strides: TVec<isize> },
13}
14
15impl DValue {
16 #[inline]
17 pub fn view(&self) -> TensorView<'_> {
18 match self {
19 Self::Natural(t) => t.view(),
20 Self::Reshaped { t, shape, strides } => unsafe {
21 TensorView::from_bytes(t, 0, shape.as_slice(), strides.as_slice())
22 },
23 }
24 }
25
26 #[inline]
28 pub fn datum_type(&self) -> DatumType {
29 match self {
30 Self::Natural(t) => t.datum_type(),
31 Self::Reshaped { t, .. } => t.datum_type(),
32 }
33 }
34
35 #[inline]
36 pub fn shape(&self) -> &[usize] {
37 match self {
38 DValue::Natural(t) => t.shape(),
39 DValue::Reshaped { shape, .. } => shape,
40 }
41 }
42
43 #[inline]
45 #[allow(clippy::len_without_is_empty)]
46 pub fn len(&self) -> usize {
47 self.shape().iter().product()
48 }
49
50 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
52 let shape = shape.into();
53 if self.len() != shape.iter().product::<usize>() {
54 bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
55 }
56 if shape.as_slice() != self.shape() {
57 match &self {
58 DValue::Natural(t) | DValue::Reshaped { t, .. } => Ok(Self::Reshaped {
59 t: Arc::clone(t),
60 strides: Tensor::natural_strides(&shape),
61 shape,
62 }),
63 }
64 } else {
65 Ok(self.clone())
66 }
67 }
68
69 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
70 let strides = strides.into();
71 check_strides_validity(self.shape().into(), strides.clone())?;
72
73 match &self {
74 DValue::Natural(t) => {
75 Ok(Self::Reshaped { t: Arc::clone(t), strides, shape: self.shape().into() })
76 }
77 DValue::Reshaped { t, strides: old_strides, .. } => {
78 if &strides != old_strides {
79 Ok(Self::Reshaped { t: Arc::clone(t), strides, shape: self.shape().into() })
80 } else {
81 Ok(self.clone())
82 }
83 }
84 }
85 }
86
87 pub fn as_arc_tensor(&self) -> Option<&Arc<Tensor>> {
88 match self {
89 DValue::Natural(t) => Some(t),
90 DValue::Reshaped { .. } => None,
91 }
92 }
93
94 pub unsafe fn reshaped_with_geometry_unchecked(
96 &self,
97 shape: impl Into<TVec<usize>>,
98 strides: impl Into<TVec<isize>>,
99 ) -> Self {
100 match self {
101 DValue::Natural(t) | DValue::Reshaped { t, .. } => {
102 DValue::Reshaped { t: Arc::clone(t), strides: strides.into(), shape: shape.into() }
103 }
104 }
105 }
106}
107
108impl IntoTensor for DValue {
109 fn into_tensor(self) -> Tensor {
110 match self {
111 Self::Natural(t) => Arc::try_unwrap(t).unwrap_or_else(|t| (*t).clone()),
112 Self::Reshaped { t, shape, strides: _ } => {
113 let mut t = Arc::try_unwrap(t).unwrap_or_else(|t| (*t).clone());
114 t.set_shape(&shape).expect("Could not apply shape to reshaped GPU tensor");
115 t
116 }
117 }
118 }
119}
120
121impl From<Tensor> for DValue {
122 fn from(v: Tensor) -> Self {
123 Self::Natural(Arc::new(v))
124 }
125}
126
127impl From<Arc<Tensor>> for DValue {
128 fn from(v: Arc<Tensor>) -> Self {
129 Self::Natural(v)
130 }
131}
132
133#[derive(Debug, Clone)]
136pub struct OwnedDeviceTensor {
137 pub inner: DValue,
138 pub device_buffer: Box<dyn DeviceBuffer>,
139}
140
141impl Hash for OwnedDeviceTensor {
142 #[inline]
143 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
144 self.inner.hash(state)
145 }
146}
147
148impl OwnedDeviceTensor {
149 pub fn from_tensor<T: Into<DValue>>(tensor: T) -> TractResult<Self> {
151 let m_value: DValue = tensor.into();
152 let tensor_view = m_value.view();
153 ensure!(
154 DeviceTensor::is_supported_dt(tensor_view.datum_type()),
155 "Tensor of {:?} is not copied. No device buffer can be allocated for it.",
156 tensor_view.datum_type(),
157 );
158
159 let data_bytes = as_q40_tensor(tensor_view.tensor)
160 .map(|bqv| bqv.value.as_bytes())
161 .unwrap_or(tensor_view.tensor.as_bytes());
162
163 let device_buffer = get_context()?.buffer_from_slice(data_bytes);
164
165 Ok(OwnedDeviceTensor { inner: m_value, device_buffer })
166 }
167
168 #[inline]
169 pub fn shape(&self) -> &[usize] {
170 self.inner.shape()
171 }
172
173 #[inline]
175 #[allow(clippy::len_without_is_empty)]
176 pub fn len(&self) -> usize {
177 self.shape().iter().product()
178 }
179
180 #[inline]
182 pub fn strides(&self) -> &[isize] {
183 match &self.inner {
184 DValue::Natural(t) => t.strides(),
185 DValue::Reshaped { strides, .. } => strides,
186 }
187 }
188
189 #[inline]
191 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
192 &(*self.device_buffer)
193 }
194
195 pub fn device_buffer_ptr(&self) -> *const c_void {
196 self.device_buffer.ptr()
197 }
198
199 #[inline]
201 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
202 where
203 usize: AsPrimitive<I>,
204 {
205 0usize.as_()
207 }
208
209 #[inline]
211 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
212 Ok(Self { inner: self.inner.reshaped(shape)?, device_buffer: self.device_buffer.clone() })
213 }
214
215 #[inline]
217 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
218 Ok(Self {
219 inner: self.inner.restrided(strides)?,
220 device_buffer: self.device_buffer.clone(),
221 })
222 }
223
224 #[inline]
226 pub unsafe fn reshaped_with_geometry_unchecked(
227 &self,
228 shape: impl Into<TVec<usize>>,
229 strides: impl Into<TVec<isize>>,
230 ) -> Self {
231 Self {
232 inner: self.inner.reshaped_with_geometry_unchecked(shape, strides),
233 device_buffer: self.device_buffer.clone(),
234 }
235 }
236
237 #[inline]
238 pub fn view(&self) -> TensorView<'_> {
239 self.inner.view()
240 }
241}
242
243impl Display for OwnedDeviceTensor {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 match &self.inner {
246 DValue::Natural(t) => {
247 let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
248 write!(f, "GPU {{ {content} }}")
249 }
250 DValue::Reshaped { t, shape, strides: _ } => {
251 let content = t.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
252 write!(f, "GPU reshaped: {:?} - {{ {content} }}", shape)
253 }
254 }
255 }
256}