tract_data/tensor/view.rs

1use super::*;
2use crate::internal::*;
3
/// Describes where a `TensorView` takes its shape and strides from.
#[derive(Clone, Debug)]
enum Indexing<'a> {
    /// Reuse the underlying tensor's shape and strides, skipping this many leading axes.
    Prefix(usize),
    /// Use an explicitly supplied shape and strides.
    Custom {
        /// Dimensions of the view.
        shape: &'a [usize],
        /// Per-axis strides, counted in elements (not bytes).
        strides: &'a [isize],
    },
}
9
10#[derive(Clone, Debug)]
11pub struct TensorView<'a> {
12    pub tensor: &'a Tensor,
13    offset_bytes: isize,
14    indexing: Indexing<'a>,
15}
16
17impl<'a> TensorView<'a> {
18    pub unsafe fn from_bytes(
19        tensor: &'a Tensor,
20        offset_bytes: isize,
21        shape: &'a [usize],
22        strides: &'a [isize],
23    ) -> TensorView<'a> {
24        TensorView { tensor, offset_bytes, indexing: Indexing::Custom { shape, strides } }
25    }
26
27    pub fn offsetting(tensor: &'a Tensor, coords: &[usize]) -> TractResult<TensorView<'a>> {
28        ensure!(
29            coords.len() == tensor.rank() && coords.iter().zip(tensor.shape()).all(|(p, d)| p < d),
30            "Invalid coords {:?} for shape {:?}",
31            coords,
32            tensor.shape()
33        );
34        unsafe { Ok(Self::offsetting_unchecked(tensor, coords)) }
35    }
36
37    pub unsafe fn offsetting_unchecked(tensor: &'a Tensor, coords: &[usize]) -> TensorView<'a> {
38        let offset_bytes =
39            coords.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
40                * tensor.datum_type().size_of() as isize;
41        TensorView {
42            tensor,
43            offset_bytes,
44            indexing: Indexing::Custom { shape: &tensor.shape, strides: &tensor.strides },
45        }
46    }
47
48    pub fn at_prefix(tensor: &'a Tensor, prefix: &[usize]) -> TractResult<TensorView<'a>> {
49        ensure!(
50            prefix.len() <= tensor.rank() && prefix.iter().zip(tensor.shape()).all(|(p, d)| p < d),
51            "Invalid prefix {:?} for shape {:?}",
52            prefix,
53            tensor.shape()
54        );
55        unsafe { Ok(Self::at_prefix_unchecked(tensor, prefix)) }
56    }
57
58    pub unsafe fn at_prefix_unchecked(tensor: &'a Tensor, prefix: &[usize]) -> TensorView<'a> {
59        let offset_bytes =
60            prefix.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
61                * tensor.datum_type().size_of() as isize;
62        TensorView { tensor, offset_bytes, indexing: Indexing::Prefix(prefix.len()) }
63    }
64
65    #[inline]
66    pub unsafe fn view(tensor: &'a Tensor) -> TensorView<'a> {
67        TensorView { tensor, offset_bytes: 0, indexing: Indexing::Prefix(0) }
68    }
69
70    #[inline]
71    pub fn datum_type(&self) -> DatumType {
72        self.tensor.datum_type()
73    }
74
75    #[inline]
76    pub fn shape(&self) -> &[usize] {
77        match &self.indexing {
78            Indexing::Prefix(i) => &self.tensor.shape()[*i..],
79            Indexing::Custom { shape, .. } => shape,
80        }
81    }
82
83    #[inline]
84    pub fn strides(&self) -> &[isize] {
85        match &self.indexing {
86            Indexing::Prefix(i) => &self.tensor.strides()[*i..],
87            Indexing::Custom { strides, .. } => strides,
88        }
89    }
90
91    #[inline]
92    #[allow(clippy::len_without_is_empty)]
93    pub fn len(&self) -> usize {
94        match &self.indexing {
95            Indexing::Prefix(i) => {
96                if *i == 0 {
97                    self.tensor.len()
98                } else {
99                    self.tensor.strides[*i - 1] as usize
100                }
101            }
102            Indexing::Custom { shape, .. } => shape.iter().product(),
103        }
104    }
105
106    #[inline]
107    #[allow(clippy::len_without_is_empty)]
108    pub fn valid_bytes(&self) -> usize {
109        self.tensor.data.layout().size() - self.offset_bytes as usize
110    }
111
112    #[inline]
113    pub fn rank(&self) -> usize {
114        match &self.indexing {
115            Indexing::Prefix(i) => self.tensor.rank() - i,
116            Indexing::Custom { shape, .. } => shape.len(),
117        }
118    }
119
120    fn check_dt<D: Datum>(&self) -> TractResult<()> {
121        self.tensor.check_for_access::<D>()
122    }
123
124    fn check_coords(&self, coords: &[usize]) -> TractResult<()> {
125        ensure!(
126            coords.len() == self.rank()
127                && coords.iter().zip(self.shape()).all(|(&x, &dim)| x < dim),
128            "Can't access coordinates {:?} of TensorView of shape {:?}",
129            coords,
130            self.shape(),
131        );
132        Ok(())
133    }
134
135    /// Access the data as a pointer.
136    #[inline]
137    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
138        self.check_dt::<D>()?;
139        Ok(unsafe { self.as_ptr_unchecked() })
140    }
141
142    /// Access the data as a pointer.
143    #[inline]
144    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
145        unsafe { self.tensor.as_ptr_unchecked::<u8>().offset(self.offset_bytes) as *const D }
146    }
147
148    /// Access the data as a pointer.
149    #[inline]
150    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
151        unsafe { self.as_ptr_unchecked::<D>() as *mut D }
152    }
153
154    /// Access the data as a mutable pointer.
155    #[inline]
156    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
157        Ok(self.as_ptr::<D>()? as *mut D)
158    }
159
160    /// Access the data as a slice.
161    #[inline]
162    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
163        unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
164    }
165
166    /// Access the data as a slice.
167    #[inline]
168    pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
169        self.check_dt::<D>()?;
170        unsafe { Ok(self.as_slice_unchecked()) }
171    }
172
173    /// Access the data as a mutable slice.
174    #[inline]
175    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
176        unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
177    }
178
179    /// Access the data as a mutable slice.
180    #[inline]
181    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
182        self.check_dt::<D>()?;
183        unsafe { Ok(self.as_slice_mut_unchecked()) }
184    }
185
186    #[inline]
187    pub unsafe fn offset_bytes(&mut self, offset: isize) {
188        self.offset_bytes += offset
189    }
190
191    #[inline]
192    pub unsafe fn offset_axis_unchecked(&mut self, axis: usize, pos: isize) {
193        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
194        unsafe { self.offset_bytes(stride * pos) }
195    }
196
197    #[inline]
198    pub unsafe fn offset_axis(&mut self, axis: usize, pos: isize) {
199        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
200        unsafe { self.offset_bytes(stride * pos) }
201    }
202
203    #[inline]
204    fn offset_for_coords(&self, coords: &[usize]) -> isize {
205        self.strides().iter().zip(coords.as_ref()).map(|(s, c)| *s * *c as isize).sum::<isize>()
206    }
207
208    #[inline]
209    pub unsafe fn at_unchecked<T: Datum>(&self, coords: impl AsRef<[usize]>) -> &T {
210        unsafe {
211            self.as_ptr_unchecked::<T>()
212                .offset(self.offset_for_coords(coords.as_ref()))
213                .as_ref()
214                .unwrap()
215        }
216    }
217
218    #[inline]
219    pub unsafe fn at_mut_unchecked<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> &mut T {
220        unsafe {
221            self.as_ptr_mut_unchecked::<T>()
222                .offset(self.offset_for_coords(coords.as_ref()))
223                .as_mut()
224                .unwrap()
225        }
226    }
227
228    #[inline]
229    pub fn at<T: Datum>(&self, coords: impl AsRef<[usize]>) -> TractResult<&T> {
230        self.check_dt::<T>()?;
231        let coords = coords.as_ref();
232        self.check_coords(coords)?;
233        unsafe { Ok(self.at_unchecked(coords)) }
234    }
235
236    #[inline]
237    pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
238        self.check_dt::<T>()?;
239        let coords = coords.as_ref();
240        self.check_coords(coords)?;
241        unsafe { Ok(self.at_mut_unchecked(coords)) }
242    }
243
244    /*
245      pub unsafe fn reshaped(&self, shape: impl AsRef<[usize]>) -> TensorView<'a> {
246      let shape = shape.as_ref();
247      let mut strides: TVec<isize> = shape
248      .iter()
249      .rev()
250      .scan(1, |state, d| {
251      let old = *state;
252    *state = *state * d;
253    Some(old as isize)
254    })
255    .collect();
256    strides.reverse();
257    TensorView { shape: shape.into(), strides, ..*self }
258    }
259    */
260}
261
#[cfg(test)]
mod test {
    use super::TensorView;
    use crate::prelude::Tensor;

    /// Fixing the first axis of a 2x2 tensor to 1 must expose its second row as a rank-1 view.
    #[test]
    fn test_at_prefix() {
        let tensor = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
        let second_row = TensorView::at_prefix(&tensor, &[1]).unwrap();
        assert_eq!(second_row.shape(), &[2]);
        let values: &[i32] = second_row.as_slice().unwrap();
        assert_eq!(values, &[3, 4]);
    }
}