Skip to main content

tract_gpu/tensor/
arena_view.rs

1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5use tract_core::tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
6
7use crate::device::{DeviceBuffer, get_context};
8use crate::utils::check_strides_validity;
9
10use super::OwnedDeviceTensor;
11
/// A typed view into a region of a shared device arena tensor.
///
/// Several views can point at the same `arena` (it is reference-counted, and
/// `reshaped`/`restrided` share it via `Arc::clone`). Each view carries its own element
/// type, shape, strides and byte offset into the arena's device buffer.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeviceArenaView {
    /// Backing arena tensor, shared between views.
    pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
    /// Datum type of the view's elements.
    pub(crate) dt: DatumType,
    /// Number of elements covered by the view.
    pub(crate) len: usize,
    /// Logical shape of the view.
    pub(crate) shape: TVec<usize>,
    /// Per-axis strides of the view.
    pub(crate) strides: TVec<isize>,
    /// Offset, in bytes, of this view's data inside the arena's buffer.
    pub(crate) offset_bytes: usize,
    /// Description of a non-plain payload (e.g. a `BlockQuantFact`), if any.
    pub(crate) exotic_fact: Option<Box<dyn ExoticFact>>,
}
22
23impl DeviceArenaView {
24    #[inline]
25    pub fn shape(&self) -> &[usize] {
26        self.shape.as_slice()
27    }
28
29    /// Get the datum type of the tensor.
30    #[inline]
31    pub fn datum_type(&self) -> DatumType {
32        self.dt
33    }
34
35    #[inline]
36    pub fn strides(&self) -> &[isize] {
37        self.strides.as_slice()
38    }
39
40    /// Get underlying inner device buffer.
41    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
42        self.arena.device_buffer()
43    }
44
45    pub fn device_buffer_ptr(&self) -> *const c_void {
46        self.arena.device_buffer().ptr()
47    }
48
49    /// Get underlying inner device buffer offset
50    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
51    where
52        usize: AsPrimitive<I>,
53    {
54        self.offset_bytes.as_()
55    }
56
57    pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
58        self.exotic_fact.as_deref()
59    }
60
61    /// Get the number of values in the tensor.
62    #[inline]
63    #[allow(clippy::len_without_is_empty)]
64    pub fn len(&self) -> usize {
65        self.len
66    }
67
68    pub fn as_bytes(&self) -> Vec<u8> {
69        let len = if let Some(of) = &self.exotic_fact {
70            of.mem_size().as_i64().unwrap() as usize
71        } else {
72            self.len() * self.dt.size_of()
73        };
74        self.arena.get_bytes_slice(self.offset_bytes, len)
75    }
76
77    /// Reshaped tensor with given shape.
78    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
79        ensure!(self.exotic_fact.is_none(), "Can't reshape exotic tensor");
80        let shape = shape.into();
81        if self.len() != shape.iter().product::<usize>() {
82            bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
83        }
84        if shape.as_slice() != self.shape() {
85            Ok(Self {
86                arena: Arc::clone(&self.arena),
87                dt: self.dt,
88                len: self.len,
89                strides: Tensor::natural_strides(&shape),
90                shape,
91                offset_bytes: self.offset_bytes,
92                exotic_fact: None,
93            })
94        } else {
95            Ok(self.clone())
96        }
97    }
98
99    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
100        ensure!(self.exotic_fact.is_none(), "Can't restride exotic tensor");
101        let strides = strides.into();
102        check_strides_validity(self.shape().into(), strides.clone())?;
103
104        if strides.as_slice() != self.strides() {
105            Ok(Self {
106                arena: Arc::clone(&self.arena),
107                dt: self.dt,
108                len: self.len,
109                strides,
110                shape: self.shape.clone(),
111                offset_bytes: self.offset_bytes,
112                exotic_fact: None,
113            })
114        } else {
115            Ok(self.clone())
116        }
117    }
118
119    pub fn to_host(&self) -> TractResult<Tensor> {
120        get_context()?.synchronize()?;
121        let content = self.as_bytes();
122        unsafe {
123            if let Some(bqf) =
124                self.exotic_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>())
125            {
126                Ok(BlockQuantStorage::new(
127                    bqf.format.clone(),
128                    bqf.m(),
129                    bqf.k(),
130                    Arc::new(Blob::from_bytes(&content)?),
131                )?
132                .into_tensor_with_shape(self.dt, bqf.shape()))
133            } else {
134                Tensor::from_raw_dt(self.dt, &self.shape, &content)
135            }
136        }
137    }
138}
139
140impl Display for DeviceArenaView {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        let content = self
143            .clone()
144            .to_host()
145            .unwrap()
146            .dump(false)
147            .unwrap_or_else(|e| format!("Error : {e:?}"));
148        write!(f, "DeviceArenaView: {{ {content} }}")
149    }
150}