Skip to main content

tract_data/tensor/
dense_view.rs

1use std::alloc::Layout;
2
3use ndarray::prelude::*;
4
5use crate::datum::{Datum, DatumType};
6use crate::internal::*;
7use crate::tensor::Tensor;
8
9use super::storage::DenseStorage;
10
11fn check_for_access<D: Datum>(dt: DatumType) -> TractResult<()> {
12    ensure!(
13        dt.unquantized() == D::datum_type().unquantized(),
14        "Tensor datum type error: tensor is {:?}, accessed as {:?}",
15        dt,
16        D::datum_type(),
17    );
18    Ok(())
19}
20
21/// Immutable view into a [`Tensor`] verified to have dense storage.
22///
23/// Construction is the single point of failure (`Tensor::as_dense()` returns
24/// `Option`). Once constructed, all data access is infallible with no
25/// `unwrap()`/`expect()` on the dense codepath.
26pub struct DenseView<'a> {
27    tensor: &'a Tensor,
28    storage: &'a DenseStorage,
29}
30
31impl<'a> DenseView<'a> {
32    /// Private constructor used by `Tensor::as_dense()`.
33    #[inline]
34    pub(crate) fn new(tensor: &'a Tensor, storage: &'a DenseStorage) -> Self {
35        DenseView { tensor, storage }
36    }
37
38    // -- Metadata (delegated to tensor) --
39
40    #[inline]
41    pub fn tensor(&self) -> &Tensor {
42        self.tensor
43    }
44
45    #[inline]
46    pub fn datum_type(&self) -> DatumType {
47        self.tensor.datum_type()
48    }
49
50    #[inline]
51    pub fn shape(&self) -> &[usize] {
52        self.tensor.shape()
53    }
54
55    #[inline]
56    pub fn strides(&self) -> &[isize] {
57        self.tensor.strides()
58    }
59
60    #[inline]
61    pub fn rank(&self) -> usize {
62        self.tensor.rank()
63    }
64
65    #[inline]
66    pub fn len(&self) -> usize {
67        self.tensor.len()
68    }
69
70    // -- Dense-specific (direct storage access, no dispatch) --
71
72    #[inline]
73    pub fn as_bytes(&self) -> &'a [u8] {
74        self.storage.as_bytes()
75    }
76
77    #[inline]
78    pub fn layout(&self) -> &Layout {
79        self.storage.layout()
80    }
81
82    // -- Typed access --
83    // TractResult is for datum-type check only, NOT dense check.
84
85    #[inline]
86    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
87        check_for_access::<D>(self.datum_type())?;
88        unsafe { Ok(self.as_ptr_unchecked()) }
89    }
90
91    #[inline]
92    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
93        self.storage.as_ptr() as *const D
94    }
95
96    #[inline]
97    pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
98        check_for_access::<D>(self.datum_type())?;
99        unsafe { Ok(self.as_slice_unchecked()) }
100    }
101
102    #[inline]
103    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
104        if self.storage.is_empty() {
105            &[]
106        } else {
107            unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len()) }
108        }
109    }
110
111    #[inline]
112    pub fn to_scalar<D: Datum>(&self) -> TractResult<&'a D> {
113        check_for_access::<D>(self.datum_type())?;
114        unsafe { Ok(self.to_scalar_unchecked()) }
115    }
116
117    #[inline]
118    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &'a D {
119        unsafe { &*(self.storage.as_ptr() as *const D) }
120    }
121
122    #[inline]
123    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'a, D>> {
124        check_for_access::<D>(self.datum_type())?;
125        unsafe { Ok(self.to_array_view_unchecked()) }
126    }
127
128    #[inline]
129    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'a, D> {
130        if self.len() != 0 {
131            unsafe { ArrayViewD::from_shape_ptr(self.shape(), self.storage.as_ptr() as *const D) }
132        } else {
133            ArrayViewD::from_shape(self.shape(), &[]).unwrap()
134        }
135    }
136}
137
138/// Mutable view into a [`Tensor`] verified to have dense storage.
139///
140/// Fields are split to satisfy the borrow checker: mutable storage +
141/// immutable metadata borrowed from the same Tensor.
142pub struct DenseViewMut<'a> {
143    dt: DatumType,
144    shape: &'a [usize],
145    strides: &'a [isize],
146    len: usize,
147    storage: &'a mut DenseStorage,
148}
149
150impl<'a> DenseViewMut<'a> {
151    /// Private constructor used by `Tensor::as_dense_mut()`.
152    #[inline]
153    pub(crate) fn new(
154        dt: DatumType,
155        shape: &'a [usize],
156        strides: &'a [isize],
157        len: usize,
158        storage: &'a mut DenseStorage,
159    ) -> Self {
160        DenseViewMut { dt, shape, strides, len, storage }
161    }
162
163    // -- Metadata --
164
165    #[inline]
166    pub fn datum_type(&self) -> DatumType {
167        self.dt
168    }
169
170    #[inline]
171    pub fn shape(&self) -> &[usize] {
172        self.shape
173    }
174
175    #[inline]
176    pub fn strides(&self) -> &[isize] {
177        self.strides
178    }
179
180    #[inline]
181    pub fn rank(&self) -> usize {
182        self.shape.len()
183    }
184
185    #[inline]
186    pub fn len(&self) -> usize {
187        self.len
188    }
189
190    // -- Read access (same as DenseView, self.storage reborrows as &DenseStorage) --
191
192    #[inline]
193    pub fn as_bytes(&self) -> &[u8] {
194        self.storage.as_bytes()
195    }
196
197    #[inline]
198    pub fn layout(&self) -> &Layout {
199        self.storage.layout()
200    }
201
202    #[inline]
203    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
204        check_for_access::<D>(self.dt)?;
205        unsafe { Ok(self.as_ptr_unchecked()) }
206    }
207
208    #[inline]
209    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
210        self.storage.as_ptr() as *const D
211    }
212
213    #[inline]
214    pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
215        check_for_access::<D>(self.dt)?;
216        unsafe { Ok(self.as_slice_unchecked()) }
217    }
218
219    #[inline]
220    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
221        if self.storage.is_empty() {
222            &[]
223        } else {
224            unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len) }
225        }
226    }
227
228    #[inline]
229    pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
230        check_for_access::<D>(self.dt)?;
231        unsafe { Ok(self.to_scalar_unchecked()) }
232    }
233
234    #[inline]
235    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
236        unsafe { &*(self.storage.as_ptr() as *const D) }
237    }
238
239    #[inline]
240    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
241        check_for_access::<D>(self.dt)?;
242        unsafe { Ok(self.to_array_view_unchecked()) }
243    }
244
245    #[inline]
246    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
247        if self.len != 0 {
248            unsafe { ArrayViewD::from_shape_ptr(self.shape, self.storage.as_ptr() as *const D) }
249        } else {
250            ArrayViewD::from_shape(self.shape, &[]).unwrap()
251        }
252    }
253
254    // -- Mutable access --
255
256    #[inline]
257    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
258        self.storage.as_bytes_mut()
259    }
260
261    #[inline]
262    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
263        check_for_access::<D>(self.dt)?;
264        unsafe { Ok(self.as_ptr_mut_unchecked()) }
265    }
266
267    #[inline]
268    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
269        self.storage.as_mut_ptr() as *mut D
270    }
271
272    #[inline]
273    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
274        check_for_access::<D>(self.dt)?;
275        unsafe { Ok(self.as_slice_mut_unchecked()) }
276    }
277
278    #[inline]
279    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
280        if self.storage.is_empty() {
281            &mut []
282        } else {
283            let len = self.len;
284            unsafe { std::slice::from_raw_parts_mut(self.as_ptr_mut_unchecked(), len) }
285        }
286    }
287
288    #[inline]
289    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
290        check_for_access::<D>(self.dt)?;
291        unsafe { Ok(self.to_scalar_mut_unchecked()) }
292    }
293
294    #[inline]
295    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
296        unsafe { &mut *(self.storage.as_mut_ptr() as *mut D) }
297    }
298
299    #[inline]
300    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
301        check_for_access::<D>(self.dt)?;
302        unsafe { Ok(self.to_array_view_mut_unchecked()) }
303    }
304
305    #[inline]
306    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
307        if self.len != 0 {
308            unsafe {
309                ArrayViewMutD::from_shape_ptr(self.shape, self.storage.as_mut_ptr() as *mut D)
310            }
311        } else {
312            ArrayViewMutD::from_shape(self.shape, &mut []).unwrap()
313        }
314    }
315}