Skip to main content

tract_data/tensor/
plain_view.rs

1use std::alloc::Layout;
2
3use ndarray::prelude::*;
4
5use crate::datum::{Datum, DatumType};
6use crate::internal::*;
7use crate::tensor::Tensor;
8
9use super::storage::PlainStorage;
10
11fn check_for_access<D: Datum>(dt: DatumType) -> TractResult<()> {
12    ensure!(
13        dt.unquantized() == D::datum_type().unquantized(),
14        "Tensor datum type error: tensor is {:?}, accessed as {:?}",
15        dt,
16        D::datum_type(),
17    );
18    Ok(())
19}
20
21/// Immutable view into a [`Tensor`] verified to have plain storage.
22///
23/// Construction is the single point of failure (`Tensor::as_plain()` returns
24/// `Option`). Once constructed, all data access is infallible with no
25/// `unwrap()`/`expect()` on the plain codepath.
26pub struct PlainView<'a> {
27    tensor: &'a Tensor,
28    storage: &'a PlainStorage,
29}
30
31impl<'a> PlainView<'a> {
32    /// Private constructor used by `Tensor::as_plain()`.
33    #[inline]
34    pub(crate) fn new(tensor: &'a Tensor, storage: &'a PlainStorage) -> Self {
35        PlainView { tensor, storage }
36    }
37
38    // -- Metadata (delegated to tensor) --
39
40    #[inline]
41    pub fn tensor(&self) -> &Tensor {
42        self.tensor
43    }
44
45    #[inline]
46    pub fn datum_type(&self) -> DatumType {
47        self.tensor.datum_type()
48    }
49
50    #[inline]
51    pub fn shape(&self) -> &[usize] {
52        self.tensor.shape()
53    }
54
55    #[inline]
56    pub fn strides(&self) -> &[isize] {
57        self.tensor.strides()
58    }
59
60    #[inline]
61    pub fn rank(&self) -> usize {
62        self.tensor.rank()
63    }
64
65    #[inline]
66    pub fn len(&self) -> usize {
67        self.tensor.len()
68    }
69
70    #[inline]
71    pub fn is_empty(&self) -> bool {
72        self.len() == 0
73    }
74
75    // -- Plain-specific (direct storage access, no dispatch) --
76
77    #[inline]
78    pub fn as_bytes(&self) -> &'a [u8] {
79        self.storage.as_bytes()
80    }
81
82    #[inline]
83    pub fn layout(&self) -> &Layout {
84        self.storage.layout()
85    }
86
87    // -- Typed access --
88    // TractResult is for datum-type check only, NOT plain check.
89
90    #[inline]
91    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
92        check_for_access::<D>(self.datum_type())?;
93        unsafe { Ok(self.as_ptr_unchecked()) }
94    }
95
96    #[inline]
97    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
98        self.storage.as_ptr() as *const D
99    }
100
101    #[inline]
102    pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
103        check_for_access::<D>(self.datum_type())?;
104        unsafe { Ok(self.as_slice_unchecked()) }
105    }
106
107    #[inline]
108    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
109        if self.storage.is_empty() {
110            &[]
111        } else {
112            unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len()) }
113        }
114    }
115
116    #[inline]
117    pub fn to_scalar<D: Datum>(&self) -> TractResult<&'a D> {
118        check_for_access::<D>(self.datum_type())?;
119        unsafe { Ok(self.to_scalar_unchecked()) }
120    }
121
122    #[inline]
123    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &'a D {
124        unsafe { &*(self.storage.as_ptr() as *const D) }
125    }
126
127    #[inline]
128    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'a, D>> {
129        check_for_access::<D>(self.datum_type())?;
130        unsafe { Ok(self.to_array_view_unchecked()) }
131    }
132
133    #[inline]
134    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'a, D> {
135        if self.len() != 0 {
136            unsafe { ArrayViewD::from_shape_ptr(self.shape(), self.storage.as_ptr() as *const D) }
137        } else {
138            ArrayViewD::from_shape(self.shape(), &[]).unwrap()
139        }
140    }
141}
142
143/// Mutable view into a [`Tensor`] verified to have plain storage.
144///
145/// Fields are split to satisfy the borrow checker: mutable storage +
146/// immutable metadata borrowed from the same Tensor.
147pub struct PlainViewMut<'a> {
148    dt: DatumType,
149    shape: &'a [usize],
150    strides: &'a [isize],
151    len: usize,
152    storage: &'a mut PlainStorage,
153}
154
155impl<'a> PlainViewMut<'a> {
156    /// Private constructor used by `Tensor::as_plain_mut()`.
157    #[inline]
158    pub(crate) fn new(
159        dt: DatumType,
160        shape: &'a [usize],
161        strides: &'a [isize],
162        len: usize,
163        storage: &'a mut PlainStorage,
164    ) -> Self {
165        PlainViewMut { dt, shape, strides, len, storage }
166    }
167
168    // -- Metadata --
169
170    #[inline]
171    pub fn datum_type(&self) -> DatumType {
172        self.dt
173    }
174
175    #[inline]
176    pub fn shape(&self) -> &[usize] {
177        self.shape
178    }
179
180    #[inline]
181    pub fn strides(&self) -> &[isize] {
182        self.strides
183    }
184
185    #[inline]
186    pub fn rank(&self) -> usize {
187        self.shape.len()
188    }
189
190    #[inline]
191    pub fn len(&self) -> usize {
192        self.len
193    }
194
195    #[inline]
196    pub fn is_empty(&self) -> bool {
197        self.len() == 0
198    }
199
200    // -- Read access (same as PlainView, self.storage reborrows as &PlainStorage) --
201
202    #[inline]
203    pub fn as_bytes(&self) -> &[u8] {
204        self.storage.as_bytes()
205    }
206
207    #[inline]
208    pub fn layout(&self) -> &Layout {
209        self.storage.layout()
210    }
211
212    #[inline]
213    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
214        check_for_access::<D>(self.dt)?;
215        unsafe { Ok(self.as_ptr_unchecked()) }
216    }
217
218    #[inline]
219    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
220        self.storage.as_ptr() as *const D
221    }
222
223    #[inline]
224    pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
225        check_for_access::<D>(self.dt)?;
226        unsafe { Ok(self.as_slice_unchecked()) }
227    }
228
229    #[inline]
230    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
231        if self.storage.is_empty() {
232            &[]
233        } else {
234            unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len) }
235        }
236    }
237
238    #[inline]
239    pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
240        check_for_access::<D>(self.dt)?;
241        unsafe { Ok(self.to_scalar_unchecked()) }
242    }
243
244    #[inline]
245    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
246        unsafe { &*(self.storage.as_ptr() as *const D) }
247    }
248
249    #[inline]
250    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
251        check_for_access::<D>(self.dt)?;
252        unsafe { Ok(self.to_array_view_unchecked()) }
253    }
254
255    #[inline]
256    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
257        if self.len != 0 {
258            unsafe { ArrayViewD::from_shape_ptr(self.shape, self.storage.as_ptr() as *const D) }
259        } else {
260            ArrayViewD::from_shape(self.shape, &[]).unwrap()
261        }
262    }
263
264    // -- Mutable access --
265
266    #[inline]
267    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
268        self.storage.as_bytes_mut()
269    }
270
271    #[inline]
272    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
273        check_for_access::<D>(self.dt)?;
274        unsafe { Ok(self.as_ptr_mut_unchecked()) }
275    }
276
277    #[inline]
278    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
279        self.storage.as_mut_ptr() as *mut D
280    }
281
282    #[inline]
283    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
284        check_for_access::<D>(self.dt)?;
285        unsafe { Ok(self.as_slice_mut_unchecked()) }
286    }
287
288    #[inline]
289    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
290        if self.storage.is_empty() {
291            &mut []
292        } else {
293            let len = self.len;
294            unsafe { std::slice::from_raw_parts_mut(self.as_ptr_mut_unchecked(), len) }
295        }
296    }
297
298    #[inline]
299    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
300        check_for_access::<D>(self.dt)?;
301        unsafe { Ok(self.to_scalar_mut_unchecked()) }
302    }
303
304    #[inline]
305    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
306        unsafe { &mut *(self.storage.as_mut_ptr() as *mut D) }
307    }
308
309    #[inline]
310    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
311        check_for_access::<D>(self.dt)?;
312        unsafe { Ok(self.to_array_view_mut_unchecked()) }
313    }
314
315    #[inline]
316    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
317        if self.len != 0 {
318            unsafe {
319                ArrayViewMutD::from_shape_ptr(self.shape, self.storage.as_mut_ptr() as *mut D)
320            }
321        } else {
322            ArrayViewMutD::from_shape(self.shape, &mut []).unwrap()
323        }
324    }
325}