1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::missing_transmute_annotations)]
3
4mod arena_view;
5mod owned;
6
7pub use arena_view::*;
8pub use owned::*;
9
10use num_traits::AsPrimitive;
11use std::ffi::c_void;
12use std::fmt::Display;
13use tract_core::internal::*;
14use tract_data::itertools::Itertools;
15
16use crate::device::{DeviceBuffer, get_context};
17
/// A tensor whose storage lives in a device (GPU) buffer.
///
/// It is either an owned tensor behind the `OwnedDeviceTensor` trait object, or a view into a
/// device arena (`DeviceArenaView`).
#[derive(Debug, Clone, Hash)]
pub enum DeviceTensor {
    /// Tensor owning its device storage; its buffer offset is always 0.
    Owned(Box<dyn OwnedDeviceTensor>),
    /// Tensor stored at some offset inside a shared arena buffer.
    ArenaView(DeviceArenaView),
}
25
26impl DeviceTensor {
27 pub const SUPPORTED_DT: [DatumType; 12] = [
28 DatumType::Bool,
29 DatumType::F32,
30 DatumType::F16,
31 DatumType::I8,
32 DatumType::U8,
33 DatumType::I16,
34 DatumType::U16,
35 DatumType::I32,
36 DatumType::U32,
37 DatumType::I64,
38 DatumType::U64,
39 DatumType::Opaque,
40 ];
41
42 pub fn tname(dt: DatumType) -> TractResult<&'static str> {
43 Ok(match dt {
44 DatumType::F32 => "f32",
45 DatumType::F16 => "f16",
46 DatumType::U8 => "u8",
47 DatumType::U16 => "u16",
48 DatumType::U32 => "u32",
49 DatumType::U64 => "u64",
50 DatumType::I8 => "i8",
51 DatumType::I16 => "i16",
52 DatumType::I32 => "i32",
53 DatumType::I64 => "i64",
54 DatumType::Bool => "bool",
55 DatumType::Opaque => "opaque",
56 _ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
57 })
58 }
59
60 pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
62 Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
63 }
64
65 pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
66 Self::uninitialized_dt(T::datum_type(), shape)
67 }
68
69 pub fn uninitialized_opaque(opaque_fact: Box<dyn OpaqueFact>) -> TractResult<DeviceTensor> {
70 Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_opaque_tensor(opaque_fact)?))
71 }
72 pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
74 Tensor::from_shape(shape, data)?.into_device()
75 }
76
77 pub fn is_supported_dt(dt: DatumType) -> bool {
78 Self::SUPPORTED_DT.contains(&dt)
79 }
80
81 #[inline]
83 pub fn datum_type(&self) -> DatumType {
84 match self {
85 Self::Owned(owned) => owned.datum_type(),
86 Self::ArenaView(view) => view.datum_type(),
87 }
88 }
89
90 #[inline]
92 pub fn rank(&self) -> usize {
93 self.shape().len()
94 }
95
96 #[inline]
98 pub fn shape(&self) -> &[usize] {
99 match self {
100 Self::Owned(t) => t.shape(),
101 Self::ArenaView(t) => t.shape(),
102 }
103 }
104
105 #[inline]
107 #[allow(clippy::len_without_is_empty)]
108 pub fn len(&self) -> usize {
109 match self {
110 Self::Owned(t) => t.len(),
111 Self::ArenaView(t) => t.len(),
112 }
113 }
114
115 #[inline]
117 pub fn strides(&self) -> &[isize] {
118 match self {
119 Self::Owned(t) => t.strides(),
120 Self::ArenaView(t) => t.strides(),
121 }
122 }
123
124 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
126 match self {
127 Self::Owned(t) => t.device_buffer(),
128 Self::ArenaView(t) => t.device_buffer(),
129 }
130 }
131
132 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
134 where
135 usize: AsPrimitive<I>,
136 {
137 match self {
138 Self::Owned(_) => 0.as_(),
139 Self::ArenaView(t) => t.buffer_offset(),
140 }
141 }
142
143 pub fn device_buffer_ptr(&self) -> *const c_void {
144 match self {
145 Self::Owned(t) => t.device_buffer().ptr(),
146 Self::ArenaView(t) => t.device_buffer().ptr(),
147 }
148 }
149
150 pub fn description(&self) -> String {
152 format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
153 }
154
155 pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
157 match self {
158 Self::Owned(t) => Ok(t.reshaped(shape)?),
159 Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
160 }
161 }
162
163 pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
164 match self {
165 Self::Owned(t) => Ok(t.restrided(strides)?),
166 Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
167 }
168 }
169
170 pub fn into_opaque_tensor(self) -> Tensor {
172 tensor0::<Opaque>(self.into())
173 }
174
175 pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
178 get_context()?.synchronize()?;
179
180 Ok(match self {
181 Self::Owned(o) => o.to_host()?,
182 Self::ArenaView(v) => v.to_host()?.into(),
183 })
184 }
185}
186
187impl Display for DeviceTensor {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 match self {
190 Self::Owned(o) => o.fmt(f),
191 Self::ArenaView(v) => {
192 let content =
193 v.to_host().unwrap().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
194 write!(f, "ArenaView: {{ {content} }}")
195 }
196 }
197 }
198}
199
/// Conversion of a host value into its device-side representation `T`.
pub trait IntoDevice<T> {
    /// Converts `self` into its device representation, consuming it.
    fn into_device(self) -> TractResult<T>;
}
203
204impl IntoDevice<DeviceTensor> for Tensor {
205 fn into_device(self) -> TractResult<DeviceTensor> {
206 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
207 }
208}
209
210impl IntoDevice<DeviceTensor> for Arc<Tensor> {
211 fn into_device(self) -> TractResult<DeviceTensor> {
212 Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
213 }
214}
215
216impl From<DeviceTensor> for Opaque {
217 fn from(value: DeviceTensor) -> Self {
218 Opaque(Arc::new(value))
219 }
220}
221
222impl From<DeviceArenaView> for DeviceTensor {
223 fn from(view: DeviceArenaView) -> Self {
224 Self::ArenaView(view)
225 }
226}
227
228impl OpaquePayload for DeviceTensor {
229 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
230 other
231 .downcast_ref::<Self>()
232 .is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
233 }
234
235 fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
236 Ok(Some(self.to_host()?))
237 }
238}
239
/// Access to a [`DeviceTensor`] wrapped inside a host `Tensor` as an `Opaque` scalar.
///
/// The `to_*` methods report failure as an error; the `as_*` methods return `None` instead.
pub trait DeviceTensorExt {
    /// Borrows the wrapped device tensor, or errors if there is none.
    fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
    /// Borrows the wrapped device tensor, if any.
    fn as_device_tensor(&self) -> Option<&DeviceTensor>;
    /// Mutably borrows the wrapped device tensor, or errors if there is none.
    fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
    /// Mutably borrows the wrapped device tensor, if any.
    fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
}
246
247impl DeviceTensorExt for Tensor {
248 fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
249 let opaque = self.to_scalar_mut::<Opaque>()?;
250 opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
251 anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
252 })
253 }
254
255 fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
256 let opaque = self.to_scalar_mut::<Opaque>().ok()?;
257 opaque.downcast_mut::<DeviceTensor>()
258 }
259
260 fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
261 let opaque = self.to_scalar::<Opaque>()?;
262 opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
263 anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
264 })
265 }
266
267 fn as_device_tensor(&self) -> Option<&DeviceTensor> {
268 let opaque = self.to_scalar::<Opaque>().ok()?;
269 opaque.downcast_ref::<DeviceTensor>()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a one-element `f32` tensor to the device and back to the host.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let on_device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let on_host = on_device.to_host()?;
        assert_eq!(on_host.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}