tract_gpu/tensor/
arena_view.rs1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5use tract_core::tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
6
7use crate::device::{DeviceBuffer, get_context};
8use crate::utils::check_strides_validity;
9
10use super::OwnedDeviceTensor;
11
/// A typed, shaped window into a device arena tensor.
///
/// The view shares ownership of its storage through an `Arc` to the arena and locates its
/// data within the arena buffer via `offset_bytes`. Derived views (`reshaped`, `restrided`)
/// clone that `Arc`, so many views may point into the same arena.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeviceArenaView {
    /// Arena tensor owning the device memory this view reads from.
    pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
    /// Element type of the view.
    pub(crate) dt: DatumType,
    /// Number of elements covered by the view (`reshaped` requires the new shape's
    /// product to equal it).
    pub(crate) len: usize,
    /// Logical shape of the view.
    pub(crate) shape: TVec<usize>,
    /// Per-axis strides: natural strides after `reshaped`, caller-supplied after
    /// `restrided`.
    pub(crate) strides: TVec<isize>,
    /// Byte offset of the view's data within the arena's device buffer.
    pub(crate) offset_bytes: usize,
    /// Present for exotic payloads (e.g. block-quantized data); such views cannot be
    /// reshaped or restrided, and their byte size comes from the fact's `mem_size()`.
    pub(crate) exotic_fact: Option<Box<dyn ExoticFact>>,
}
21}
22
23impl DeviceArenaView {
24 #[inline]
25 pub fn shape(&self) -> &[usize] {
26 self.shape.as_slice()
27 }
28
29 #[inline]
31 pub fn datum_type(&self) -> DatumType {
32 self.dt
33 }
34
35 #[inline]
36 pub fn strides(&self) -> &[isize] {
37 self.strides.as_slice()
38 }
39
40 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
42 self.arena.device_buffer()
43 }
44
45 pub fn device_buffer_ptr(&self) -> *const c_void {
46 self.arena.device_buffer().ptr()
47 }
48
49 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
51 where
52 usize: AsPrimitive<I>,
53 {
54 self.offset_bytes.as_()
55 }
56
57 pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
58 self.exotic_fact.as_deref()
59 }
60
61 #[inline]
63 #[allow(clippy::len_without_is_empty)]
64 pub fn len(&self) -> usize {
65 self.len
66 }
67
68 pub fn as_bytes(&self) -> Vec<u8> {
69 let len = if let Some(of) = &self.exotic_fact {
70 of.mem_size().as_i64().unwrap() as usize
71 } else {
72 self.len() * self.dt.size_of()
73 };
74 self.arena.get_bytes_slice(self.offset_bytes, len)
75 }
76
77 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
79 ensure!(self.exotic_fact.is_none(), "Can't reshape exotic tensor");
80 let shape = shape.into();
81 if self.len() != shape.iter().product::<usize>() {
82 bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
83 }
84 if shape.as_slice() != self.shape() {
85 Ok(Self {
86 arena: Arc::clone(&self.arena),
87 dt: self.dt,
88 len: self.len,
89 strides: Tensor::natural_strides(&shape),
90 shape,
91 offset_bytes: self.offset_bytes,
92 exotic_fact: None,
93 })
94 } else {
95 Ok(self.clone())
96 }
97 }
98
99 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
100 ensure!(self.exotic_fact.is_none(), "Can't restride exotic tensor");
101 let strides = strides.into();
102 check_strides_validity(self.shape().into(), strides.clone())?;
103
104 if strides.as_slice() != self.strides() {
105 Ok(Self {
106 arena: Arc::clone(&self.arena),
107 dt: self.dt,
108 len: self.len,
109 strides,
110 shape: self.shape.clone(),
111 offset_bytes: self.offset_bytes,
112 exotic_fact: None,
113 })
114 } else {
115 Ok(self.clone())
116 }
117 }
118
119 pub fn to_host(&self) -> TractResult<Tensor> {
120 get_context()?.synchronize()?;
121 let content = self.as_bytes();
122 unsafe {
123 if let Some(bqf) =
124 self.exotic_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>())
125 {
126 Ok(BlockQuantStorage::new(
127 bqf.format.clone(),
128 bqf.m(),
129 bqf.k(),
130 Arc::new(Blob::from_bytes(&content)?),
131 )?
132 .into_tensor_with_shape(self.dt, bqf.shape()))
133 } else {
134 Tensor::from_raw_dt(self.dt, &self.shape, &content)
135 }
136 }
137 }
138}
139
140impl Display for DeviceArenaView {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 let content = self
143 .clone()
144 .to_host()
145 .unwrap()
146 .dump(false)
147 .unwrap_or_else(|e| format!("Error : {e:?}"));
148 write!(f, "DeviceArenaView: {{ {content} }}")
149 }
150}