tract_data/tensor/
plain_view.rs1use std::alloc::Layout;
2
3use ndarray::prelude::*;
4
5use crate::datum::{Datum, DatumType};
6use crate::internal::*;
7use crate::tensor::Tensor;
8
9use super::storage::PlainStorage;
10
11fn check_for_access<D: Datum>(dt: DatumType) -> TractResult<()> {
12 ensure!(
13 dt.unquantized() == D::datum_type().unquantized(),
14 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
15 dt,
16 D::datum_type(),
17 );
18 Ok(())
19}
20
21pub struct PlainView<'a> {
27 tensor: &'a Tensor,
28 storage: &'a PlainStorage,
29}
30
31impl<'a> PlainView<'a> {
32 #[inline]
34 pub(crate) fn new(tensor: &'a Tensor, storage: &'a PlainStorage) -> Self {
35 PlainView { tensor, storage }
36 }
37
38 #[inline]
41 pub fn tensor(&self) -> &Tensor {
42 self.tensor
43 }
44
45 #[inline]
46 pub fn datum_type(&self) -> DatumType {
47 self.tensor.datum_type()
48 }
49
50 #[inline]
51 pub fn shape(&self) -> &[usize] {
52 self.tensor.shape()
53 }
54
55 #[inline]
56 pub fn strides(&self) -> &[isize] {
57 self.tensor.strides()
58 }
59
60 #[inline]
61 pub fn rank(&self) -> usize {
62 self.tensor.rank()
63 }
64
65 #[inline]
66 pub fn len(&self) -> usize {
67 self.tensor.len()
68 }
69
70 #[inline]
71 pub fn is_empty(&self) -> bool {
72 self.len() == 0
73 }
74
75 #[inline]
78 pub fn as_bytes(&self) -> &'a [u8] {
79 self.storage.as_bytes()
80 }
81
82 #[inline]
83 pub fn layout(&self) -> &Layout {
84 self.storage.layout()
85 }
86
87 #[inline]
91 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
92 check_for_access::<D>(self.datum_type())?;
93 unsafe { Ok(self.as_ptr_unchecked()) }
94 }
95
96 #[inline]
97 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
98 self.storage.as_ptr() as *const D
99 }
100
101 #[inline]
102 pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
103 check_for_access::<D>(self.datum_type())?;
104 unsafe { Ok(self.as_slice_unchecked()) }
105 }
106
107 #[inline]
108 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
109 if self.storage.is_empty() {
110 &[]
111 } else {
112 unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len()) }
113 }
114 }
115
116 #[inline]
117 pub fn to_scalar<D: Datum>(&self) -> TractResult<&'a D> {
118 check_for_access::<D>(self.datum_type())?;
119 unsafe { Ok(self.to_scalar_unchecked()) }
120 }
121
122 #[inline]
123 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &'a D {
124 unsafe { &*(self.storage.as_ptr() as *const D) }
125 }
126
127 #[inline]
128 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'a, D>> {
129 check_for_access::<D>(self.datum_type())?;
130 unsafe { Ok(self.to_array_view_unchecked()) }
131 }
132
133 #[inline]
134 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'a, D> {
135 if self.len() != 0 {
136 unsafe { ArrayViewD::from_shape_ptr(self.shape(), self.storage.as_ptr() as *const D) }
137 } else {
138 ArrayViewD::from_shape(self.shape(), &[]).unwrap()
139 }
140 }
141}
142
143pub struct PlainViewMut<'a> {
148 dt: DatumType,
149 shape: &'a [usize],
150 strides: &'a [isize],
151 len: usize,
152 storage: &'a mut PlainStorage,
153}
154
155impl<'a> PlainViewMut<'a> {
156 #[inline]
158 pub(crate) fn new(
159 dt: DatumType,
160 shape: &'a [usize],
161 strides: &'a [isize],
162 len: usize,
163 storage: &'a mut PlainStorage,
164 ) -> Self {
165 PlainViewMut { dt, shape, strides, len, storage }
166 }
167
168 #[inline]
171 pub fn datum_type(&self) -> DatumType {
172 self.dt
173 }
174
175 #[inline]
176 pub fn shape(&self) -> &[usize] {
177 self.shape
178 }
179
180 #[inline]
181 pub fn strides(&self) -> &[isize] {
182 self.strides
183 }
184
185 #[inline]
186 pub fn rank(&self) -> usize {
187 self.shape.len()
188 }
189
190 #[inline]
191 pub fn len(&self) -> usize {
192 self.len
193 }
194
195 #[inline]
196 pub fn is_empty(&self) -> bool {
197 self.len() == 0
198 }
199
200 #[inline]
203 pub fn as_bytes(&self) -> &[u8] {
204 self.storage.as_bytes()
205 }
206
207 #[inline]
208 pub fn layout(&self) -> &Layout {
209 self.storage.layout()
210 }
211
212 #[inline]
213 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
214 check_for_access::<D>(self.dt)?;
215 unsafe { Ok(self.as_ptr_unchecked()) }
216 }
217
218 #[inline]
219 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
220 self.storage.as_ptr() as *const D
221 }
222
223 #[inline]
224 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
225 check_for_access::<D>(self.dt)?;
226 unsafe { Ok(self.as_slice_unchecked()) }
227 }
228
229 #[inline]
230 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
231 if self.storage.is_empty() {
232 &[]
233 } else {
234 unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len) }
235 }
236 }
237
238 #[inline]
239 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
240 check_for_access::<D>(self.dt)?;
241 unsafe { Ok(self.to_scalar_unchecked()) }
242 }
243
244 #[inline]
245 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
246 unsafe { &*(self.storage.as_ptr() as *const D) }
247 }
248
249 #[inline]
250 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
251 check_for_access::<D>(self.dt)?;
252 unsafe { Ok(self.to_array_view_unchecked()) }
253 }
254
255 #[inline]
256 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
257 if self.len != 0 {
258 unsafe { ArrayViewD::from_shape_ptr(self.shape, self.storage.as_ptr() as *const D) }
259 } else {
260 ArrayViewD::from_shape(self.shape, &[]).unwrap()
261 }
262 }
263
264 #[inline]
267 pub fn as_bytes_mut(&mut self) -> &mut [u8] {
268 self.storage.as_bytes_mut()
269 }
270
271 #[inline]
272 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
273 check_for_access::<D>(self.dt)?;
274 unsafe { Ok(self.as_ptr_mut_unchecked()) }
275 }
276
277 #[inline]
278 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
279 self.storage.as_mut_ptr() as *mut D
280 }
281
282 #[inline]
283 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
284 check_for_access::<D>(self.dt)?;
285 unsafe { Ok(self.as_slice_mut_unchecked()) }
286 }
287
288 #[inline]
289 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
290 if self.storage.is_empty() {
291 &mut []
292 } else {
293 let len = self.len;
294 unsafe { std::slice::from_raw_parts_mut(self.as_ptr_mut_unchecked(), len) }
295 }
296 }
297
298 #[inline]
299 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
300 check_for_access::<D>(self.dt)?;
301 unsafe { Ok(self.to_scalar_mut_unchecked()) }
302 }
303
304 #[inline]
305 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
306 unsafe { &mut *(self.storage.as_mut_ptr() as *mut D) }
307 }
308
309 #[inline]
310 pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
311 check_for_access::<D>(self.dt)?;
312 unsafe { Ok(self.to_array_view_mut_unchecked()) }
313 }
314
315 #[inline]
316 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
317 if self.len != 0 {
318 unsafe {
319 ArrayViewMutD::from_shape_ptr(self.shape, self.storage.as_mut_ptr() as *mut D)
320 }
321 } else {
322 ArrayViewMutD::from_shape(self.shape, &mut []).unwrap()
323 }
324 }
325}