tract_gpu/tensor/arena_view.rs

1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5
6use crate::device::DeviceBuffer;
7use crate::utils::check_strides_validity;
8
9use super::OwnedDeviceTensor;
10
/// A typed, shaped window into a shared device arena.
///
/// The arena is held behind an `Arc`, so several views (e.g. those produced by `reshaped` /
/// `restrided`) can share the same backing storage while each records its own datum type,
/// shape, strides and byte offset.
#[derive(Debug, Clone, Hash)]
pub struct DeviceArenaView {
    /// Shared backing storage for the whole arena.
    pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
    /// Datum type of the values in this view.
    pub(crate) dt: DatumType,
    /// Number of values in the view (`as_bytes` spans `len * dt.size_of()` bytes).
    pub(crate) len: usize,
    /// Logical dimensions of the view.
    pub(crate) shape: TVec<usize>,
    /// Per-axis strides of the view (`reshaped` rebuilds them with `Tensor::natural_strides`).
    pub(crate) strides: TVec<isize>,
    /// Byte offset of the view's first value inside the arena.
    pub(crate) offset_bytes: usize,
}
20
21impl DeviceArenaView {
22    #[inline]
23    pub fn shape(&self) -> &[usize] {
24        self.shape.as_slice()
25    }
26
27    /// Get the datum type of the tensor.
28    #[inline]
29    pub fn datum_type(&self) -> DatumType {
30        self.dt
31    }
32
33    #[inline]
34    pub fn strides(&self) -> &[isize] {
35        self.strides.as_slice()
36    }
37
38    /// Get underlying inner device buffer.
39    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
40        self.arena.device_buffer()
41    }
42
43    pub fn device_buffer_ptr(&self) -> *const c_void {
44        self.arena.device_buffer().ptr()
45    }
46
47    /// Get underlying inner device buffer offset
48    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
49    where
50        usize: AsPrimitive<I>,
51    {
52        self.offset_bytes.as_()
53    }
54
55    /// Get the number of values in the tensor.
56    #[inline]
57    #[allow(clippy::len_without_is_empty)]
58    pub fn len(&self) -> usize {
59        self.len
60    }
61
62    pub fn as_bytes(&self) -> &[u8] {
63        &self.arena.as_arc_tensor().unwrap().as_bytes()
64            [self.offset_bytes..self.offset_bytes + self.len() * self.dt.size_of()]
65    }
66
67    #[inline]
68    pub fn view(&self) -> TensorView<'_> {
69        unsafe {
70            TensorView::from_bytes(
71                self.arena.as_arc_tensor().unwrap(),
72                self.offset_bytes as _,
73                self.shape.as_slice(),
74                self.strides.as_slice(),
75            )
76        }
77    }
78
79    /// Reshaped tensor with given shape.
80    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
81        let shape = shape.into();
82        if self.len() != shape.iter().product::<usize>() {
83            bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
84        }
85        if shape.as_slice() != self.shape() {
86            Ok(Self {
87                arena: Arc::clone(&self.arena),
88                dt: self.dt,
89                len: self.len,
90                strides: Tensor::natural_strides(&shape),
91                shape,
92                offset_bytes: self.offset_bytes,
93            })
94        } else {
95            Ok(self.clone())
96        }
97    }
98
99    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
100        let strides = strides.into();
101        check_strides_validity(self.shape().into(), strides.clone())?;
102
103        if strides.as_slice() != self.strides() {
104            Ok(Self {
105                arena: Arc::clone(&self.arena),
106                dt: self.dt,
107                len: self.len,
108                strides,
109                shape: self.shape.clone(),
110                offset_bytes: self.offset_bytes,
111            })
112        } else {
113            Ok(self.clone())
114        }
115    }
116}
117
118impl Display for DeviceArenaView {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        let content =
121            self.clone().into_tensor().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
122        write!(f, "DeviceArenaView: {{ {content} }}")
123    }
124}
125
126impl IntoTensor for DeviceArenaView {
127    fn into_tensor(self) -> Tensor {
128        unsafe {
129            Tensor::from_raw_dt(self.dt, &self.shape, self.as_bytes())
130                .expect("Could not transform a DeviceArenaView to tensor")
131        }
132    }
133}