tract_data/tensor/
dense_view.rs1use std::alloc::Layout;
2
3use ndarray::prelude::*;
4
5use crate::datum::{Datum, DatumType};
6use crate::internal::*;
7use crate::tensor::Tensor;
8
9use super::storage::DenseStorage;
10
11fn check_for_access<D: Datum>(dt: DatumType) -> TractResult<()> {
12 ensure!(
13 dt.unquantized() == D::datum_type().unquantized(),
14 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
15 dt,
16 D::datum_type(),
17 );
18 Ok(())
19}
20
21pub struct DenseView<'a> {
27 tensor: &'a Tensor,
28 storage: &'a DenseStorage,
29}
30
31impl<'a> DenseView<'a> {
32 #[inline]
34 pub(crate) fn new(tensor: &'a Tensor, storage: &'a DenseStorage) -> Self {
35 DenseView { tensor, storage }
36 }
37
38 #[inline]
41 pub fn tensor(&self) -> &Tensor {
42 self.tensor
43 }
44
45 #[inline]
46 pub fn datum_type(&self) -> DatumType {
47 self.tensor.datum_type()
48 }
49
50 #[inline]
51 pub fn shape(&self) -> &[usize] {
52 self.tensor.shape()
53 }
54
55 #[inline]
56 pub fn strides(&self) -> &[isize] {
57 self.tensor.strides()
58 }
59
60 #[inline]
61 pub fn rank(&self) -> usize {
62 self.tensor.rank()
63 }
64
65 #[inline]
66 pub fn len(&self) -> usize {
67 self.tensor.len()
68 }
69
70 #[inline]
73 pub fn as_bytes(&self) -> &'a [u8] {
74 self.storage.as_bytes()
75 }
76
77 #[inline]
78 pub fn layout(&self) -> &Layout {
79 self.storage.layout()
80 }
81
82 #[inline]
86 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
87 check_for_access::<D>(self.datum_type())?;
88 unsafe { Ok(self.as_ptr_unchecked()) }
89 }
90
91 #[inline]
92 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
93 self.storage.as_ptr() as *const D
94 }
95
96 #[inline]
97 pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
98 check_for_access::<D>(self.datum_type())?;
99 unsafe { Ok(self.as_slice_unchecked()) }
100 }
101
102 #[inline]
103 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
104 if self.storage.is_empty() {
105 &[]
106 } else {
107 unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len()) }
108 }
109 }
110
111 #[inline]
112 pub fn to_scalar<D: Datum>(&self) -> TractResult<&'a D> {
113 check_for_access::<D>(self.datum_type())?;
114 unsafe { Ok(self.to_scalar_unchecked()) }
115 }
116
117 #[inline]
118 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &'a D {
119 unsafe { &*(self.storage.as_ptr() as *const D) }
120 }
121
122 #[inline]
123 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'a, D>> {
124 check_for_access::<D>(self.datum_type())?;
125 unsafe { Ok(self.to_array_view_unchecked()) }
126 }
127
128 #[inline]
129 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'a, D> {
130 if self.len() != 0 {
131 unsafe { ArrayViewD::from_shape_ptr(self.shape(), self.storage.as_ptr() as *const D) }
132 } else {
133 ArrayViewD::from_shape(self.shape(), &[]).unwrap()
134 }
135 }
136}
137
138pub struct DenseViewMut<'a> {
143 dt: DatumType,
144 shape: &'a [usize],
145 strides: &'a [isize],
146 len: usize,
147 storage: &'a mut DenseStorage,
148}
149
150impl<'a> DenseViewMut<'a> {
151 #[inline]
153 pub(crate) fn new(
154 dt: DatumType,
155 shape: &'a [usize],
156 strides: &'a [isize],
157 len: usize,
158 storage: &'a mut DenseStorage,
159 ) -> Self {
160 DenseViewMut { dt, shape, strides, len, storage }
161 }
162
163 #[inline]
166 pub fn datum_type(&self) -> DatumType {
167 self.dt
168 }
169
170 #[inline]
171 pub fn shape(&self) -> &[usize] {
172 self.shape
173 }
174
175 #[inline]
176 pub fn strides(&self) -> &[isize] {
177 self.strides
178 }
179
180 #[inline]
181 pub fn rank(&self) -> usize {
182 self.shape.len()
183 }
184
185 #[inline]
186 pub fn len(&self) -> usize {
187 self.len
188 }
189
190 #[inline]
193 pub fn as_bytes(&self) -> &[u8] {
194 self.storage.as_bytes()
195 }
196
197 #[inline]
198 pub fn layout(&self) -> &Layout {
199 self.storage.layout()
200 }
201
202 #[inline]
203 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
204 check_for_access::<D>(self.dt)?;
205 unsafe { Ok(self.as_ptr_unchecked()) }
206 }
207
208 #[inline]
209 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
210 self.storage.as_ptr() as *const D
211 }
212
213 #[inline]
214 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
215 check_for_access::<D>(self.dt)?;
216 unsafe { Ok(self.as_slice_unchecked()) }
217 }
218
219 #[inline]
220 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
221 if self.storage.is_empty() {
222 &[]
223 } else {
224 unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len) }
225 }
226 }
227
228 #[inline]
229 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
230 check_for_access::<D>(self.dt)?;
231 unsafe { Ok(self.to_scalar_unchecked()) }
232 }
233
234 #[inline]
235 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
236 unsafe { &*(self.storage.as_ptr() as *const D) }
237 }
238
239 #[inline]
240 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
241 check_for_access::<D>(self.dt)?;
242 unsafe { Ok(self.to_array_view_unchecked()) }
243 }
244
245 #[inline]
246 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
247 if self.len != 0 {
248 unsafe { ArrayViewD::from_shape_ptr(self.shape, self.storage.as_ptr() as *const D) }
249 } else {
250 ArrayViewD::from_shape(self.shape, &[]).unwrap()
251 }
252 }
253
254 #[inline]
257 pub fn as_bytes_mut(&mut self) -> &mut [u8] {
258 self.storage.as_bytes_mut()
259 }
260
261 #[inline]
262 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
263 check_for_access::<D>(self.dt)?;
264 unsafe { Ok(self.as_ptr_mut_unchecked()) }
265 }
266
267 #[inline]
268 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
269 self.storage.as_mut_ptr() as *mut D
270 }
271
272 #[inline]
273 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
274 check_for_access::<D>(self.dt)?;
275 unsafe { Ok(self.as_slice_mut_unchecked()) }
276 }
277
278 #[inline]
279 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
280 if self.storage.is_empty() {
281 &mut []
282 } else {
283 let len = self.len;
284 unsafe { std::slice::from_raw_parts_mut(self.as_ptr_mut_unchecked(), len) }
285 }
286 }
287
288 #[inline]
289 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
290 check_for_access::<D>(self.dt)?;
291 unsafe { Ok(self.to_scalar_mut_unchecked()) }
292 }
293
294 #[inline]
295 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
296 unsafe { &mut *(self.storage.as_mut_ptr() as *mut D) }
297 }
298
299 #[inline]
300 pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
301 check_for_access::<D>(self.dt)?;
302 unsafe { Ok(self.to_array_view_mut_unchecked()) }
303 }
304
305 #[inline]
306 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
307 if self.len != 0 {
308 unsafe {
309 ArrayViewMutD::from_shape_ptr(self.shape, self.storage.as_mut_ptr() as *mut D)
310 }
311 } else {
312 ArrayViewMutD::from_shape(self.shape, &mut []).unwrap()
313 }
314 }
315}