Skip to main content

tract_gpu/tensor/
arena_view.rs

1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5
6use crate::device::{DeviceBuffer, get_context};
7use crate::utils::check_strides_validity;
8
9use super::OwnedDeviceTensor;
10
/// A typed, shaped view over a region of a shared device tensor (the "arena").
///
/// The view does not own device memory itself: it holds a reference-counted
/// handle to the arena plus the metadata needed to interpret the bytes starting
/// at `offset_bytes`.
#[derive(Debug, Clone, Hash)]
pub struct DeviceArenaView {
    /// Shared handle to the owned device tensor whose bytes this view reads.
    pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
    /// Datum type of the values exposed by the view.
    pub(crate) dt: DatumType,
    /// Number of values in the view (`reshaped` requires it to equal the
    /// product of the requested shape).
    pub(crate) len: usize,
    /// Shape of the view.
    pub(crate) shape: TVec<usize>,
    /// Per-axis strides; `reshaped` resets them to `Tensor::natural_strides`.
    /// NOTE(review): presumably expressed in elements rather than bytes — confirm.
    pub(crate) strides: TVec<isize>,
    /// Offset of the view inside the arena, passed as-is to
    /// `get_bytes_slice` when reading the content.
    pub(crate) offset_bytes: usize,
    /// Optional opaque fact; when present its `mem_size()` determines how many
    /// bytes `as_bytes` reads, and `to_host` requires it for `Opaque` views.
    pub(crate) opaque_fact: Option<Box<dyn OpaqueFact>>,
}
21
22impl DeviceArenaView {
23    #[inline]
24    pub fn shape(&self) -> &[usize] {
25        self.shape.as_slice()
26    }
27
28    /// Get the datum type of the tensor.
29    #[inline]
30    pub fn datum_type(&self) -> DatumType {
31        self.dt
32    }
33
34    #[inline]
35    pub fn strides(&self) -> &[isize] {
36        self.strides.as_slice()
37    }
38
39    /// Get underlying inner device buffer.
40    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
41        self.arena.device_buffer()
42    }
43
44    pub fn device_buffer_ptr(&self) -> *const c_void {
45        self.arena.device_buffer().ptr()
46    }
47
48    /// Get underlying inner device buffer offset
49    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
50    where
51        usize: AsPrimitive<I>,
52    {
53        self.offset_bytes.as_()
54    }
55
56    pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
57        self.opaque_fact.as_deref()
58    }
59
60    /// Get the number of values in the tensor.
61    #[inline]
62    #[allow(clippy::len_without_is_empty)]
63    pub fn len(&self) -> usize {
64        self.len
65    }
66
67    pub fn as_bytes(&self) -> Vec<u8> {
68        let len = if let Some(of) = &self.opaque_fact {
69            of.mem_size().as_i64().unwrap() as usize
70        } else {
71            self.len() * self.dt.size_of()
72        };
73        self.arena.get_bytes_slice(self.offset_bytes, len)
74    }
75
76    /// Reshaped tensor with given shape.
77    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
78        ensure!(self.opaque_fact.is_none(), "Can't reshape opaque tensor");
79        let shape = shape.into();
80        if self.len() != shape.iter().product::<usize>() {
81            bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
82        }
83        if shape.as_slice() != self.shape() {
84            Ok(Self {
85                arena: Arc::clone(&self.arena),
86                dt: self.dt,
87                len: self.len,
88                strides: Tensor::natural_strides(&shape),
89                shape,
90                offset_bytes: self.offset_bytes,
91                opaque_fact: None,
92            })
93        } else {
94            Ok(self.clone())
95        }
96    }
97
98    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
99        ensure!(self.opaque_fact.is_none(), "Can't restride opaque tensor");
100        let strides = strides.into();
101        check_strides_validity(self.shape().into(), strides.clone())?;
102
103        if strides.as_slice() != self.strides() {
104            Ok(Self {
105                arena: Arc::clone(&self.arena),
106                dt: self.dt,
107                len: self.len,
108                strides,
109                shape: self.shape.clone(),
110                offset_bytes: self.offset_bytes,
111                opaque_fact: None,
112            })
113        } else {
114            Ok(self.clone())
115        }
116    }
117
118    pub fn to_host(&self) -> TractResult<Tensor> {
119        get_context()?.synchronize()?;
120        let content = self.as_bytes();
121        unsafe {
122            if self.dt == DatumType::Opaque {
123                ensure!(self.len == 1, "Expected scalar Opaque");
124                Ok(tensor0(Opaque(Arc::new(BlobWithFact {
125                    fact: self
126                        .opaque_fact
127                        .clone()
128                        .context("Expected Opaque Fact for Opaque ArenaView")?,
129                    value: Arc::new(Blob::from_bytes(&content)?),
130                }))))
131            } else {
132                Tensor::from_raw_dt(self.dt, &self.shape, &content)
133            }
134        }
135    }
136}
137
138impl Display for DeviceArenaView {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        let content = self
141            .clone()
142            .to_host()
143            .unwrap()
144            .dump(false)
145            .unwrap_or_else(|e| format!("Error : {e:?}"));
146        write!(f, "DeviceArenaView: {{ {content} }}")
147    }
148}