// tract_data/tensor/view.rs

1use super::*;
2use crate::internal::*;
3
/// Describes how a tensor view derives its shape and strides.
#[derive(Clone, Debug)]
enum Indexing<'a> {
    /// The first `n` axes of the underlying tensor are fixed (their position is folded into
    /// the view's byte offset); the view reuses the tensor's own shape and strides from axis
    /// `n` onward.
    Prefix(usize),
    /// The view carries its own shape and strides (strides are counted in elements).
    Custom {
        shape: &'a [usize],
        strides: &'a [isize],
    },
}
9
10#[derive(Clone, Debug)]
11pub struct TensorView<'a> {
12    pub tensor: &'a Tensor,
13    offset_bytes: isize,
14    indexing: Indexing<'a>,
15}
16
17impl<'a> TensorView<'a> {
18    pub unsafe fn from_bytes(
19        tensor: &'a Tensor,
20        offset_bytes: isize,
21        shape: &'a [usize],
22        strides: &'a [isize],
23    ) -> TensorView<'a> {
24        TensorView {
25            tensor,
26            offset_bytes,
27            indexing: Indexing::Custom { shape, strides },
28        }
29    }
30
31    pub fn offsetting(tensor: &'a Tensor, coords: &[usize]) -> TractResult<TensorView<'a>> {
32        ensure!(
33            coords.len() == tensor.rank() && coords.iter().zip(tensor.shape()).all(|(p, d)| p < d),
34            "Invalid coords {:?} for shape {:?}",
35            coords,
36            tensor.shape()
37        );
38        unsafe { Ok(Self::offsetting_unchecked(tensor, coords)) }
39    }
40
41    pub unsafe fn offsetting_unchecked(tensor: &'a Tensor, coords: &[usize]) -> TensorView<'a> {
42        let offset_bytes =
43            coords.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
44                * tensor.datum_type().size_of() as isize;
45        TensorView {
46            tensor,
47            offset_bytes,
48            indexing: Indexing::Custom {
49                shape: &tensor.shape,
50                strides: &tensor.strides,
51            },
52        }
53    }
54
55    pub fn at_prefix(tensor: &'a Tensor, prefix: &[usize]) -> TractResult<TensorView<'a>> {
56        ensure!(
57            prefix.len() <= tensor.rank() && prefix.iter().zip(tensor.shape()).all(|(p, d)| p < d),
58            "Invalid prefix {:?} for shape {:?}",
59            prefix,
60            tensor.shape()
61        );
62        unsafe { Ok(Self::at_prefix_unchecked(tensor, prefix)) }
63    }
64
65    pub unsafe fn at_prefix_unchecked(tensor: &'a Tensor, prefix: &[usize]) -> TensorView<'a> {
66        let offset_bytes =
67            prefix.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
68                * tensor.datum_type().size_of() as isize;
69        TensorView { tensor, offset_bytes, indexing: Indexing::Prefix(prefix.len()) }
70    }
71
72    #[inline]
73    pub unsafe fn view(tensor: &'a Tensor) -> TensorView<'a> {
74        TensorView { tensor, offset_bytes: 0, indexing: Indexing::Prefix(0) }
75    }
76
77    #[inline]
78    pub fn datum_type(&self) -> DatumType {
79        self.tensor.datum_type()
80    }
81
82    #[inline]
83    pub fn shape(&self) -> &[usize] {
84        match &self.indexing {
85            Indexing::Prefix(i) => &self.tensor.shape()[*i..],
86            Indexing::Custom { shape, .. } => shape,
87        }
88    }
89
90    #[inline]
91    pub fn strides(&self) -> &[isize] {
92        match &self.indexing {
93            Indexing::Prefix(i) => &self.tensor.strides()[*i..],
94            Indexing::Custom { strides, .. } => strides,
95        }
96    }
97
98    #[inline]
99    #[allow(clippy::len_without_is_empty)]
100    pub fn len(&self) -> usize {
101        match &self.indexing {
102            Indexing::Prefix(i) => {
103                if *i == 0 {
104                    self.tensor.len()
105                } else {
106                    self.tensor.strides[*i - 1] as usize
107                }
108            }
109            Indexing::Custom { shape, .. } => shape.iter().product(),
110        }
111    }
112
113    #[inline]
114    #[allow(clippy::len_without_is_empty)]
115    pub fn valid_bytes(&self) -> usize {
116        self.tensor.data.layout().size() - self.offset_bytes as usize
117    }
118
119    #[inline]
120    pub fn rank(&self) -> usize {
121        match &self.indexing {
122            Indexing::Prefix(i) => self.tensor.rank() - i,
123            Indexing::Custom { shape, .. } => shape.len(),
124        }
125    }
126
127    fn check_dt<D: Datum>(&self) -> TractResult<()> {
128        self.tensor.check_for_access::<D>()
129    }
130
131    fn check_coords(&self, coords: &[usize]) -> TractResult<()> {
132        ensure!(
133            coords.len() == self.rank()
134                && coords.iter().zip(self.shape()).all(|(&x, &dim)| x < dim),
135            "Can't access coordinates {:?} of TensorView of shape {:?}",
136            coords,
137            self.shape(),
138        );
139        Ok(())
140    }
141
142    /// Access the data as a pointer.
143    #[inline]
144    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
145        self.check_dt::<D>()?;
146        Ok(unsafe { self.as_ptr_unchecked() })
147    }
148
149    /// Access the data as a pointer.
150    #[inline]
151    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
152        self.tensor.as_ptr_unchecked::<u8>().offset(self.offset_bytes) as *const D
153    }
154
155    /// Access the data as a pointer.
156    #[inline]
157    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
158        self.as_ptr_unchecked::<D>() as *mut D
159    }
160
161    /// Access the data as a mutable pointer.
162    #[inline]
163    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
164        Ok(self.as_ptr::<D>()? as *mut D)
165    }
166
167    /// Access the data as a slice.
168    #[inline]
169    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
170        std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
171    }
172
173    /// Access the data as a slice.
174    #[inline]
175    pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
176        self.check_dt::<D>()?;
177        unsafe { Ok(self.as_slice_unchecked()) }
178    }
179
180    /// Access the data as a mutable slice.
181    #[inline]
182    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
183        std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
184    }
185
186    /// Access the data as a mutable slice.
187    #[inline]
188    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
189        self.check_dt::<D>()?;
190        unsafe { Ok(self.as_slice_mut_unchecked()) }
191    }
192
193    #[inline]
194    pub unsafe fn offset_bytes(&mut self, offset: isize) {
195        self.offset_bytes += offset
196    }
197
198    #[inline]
199    pub unsafe fn offset_axis_unchecked(&mut self, axis: usize, pos: isize) {
200        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
201        self.offset_bytes(stride * pos)
202    }
203
204    #[inline]
205    pub unsafe fn offset_axis(&mut self, axis: usize, pos: isize) {
206        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
207        self.offset_bytes(stride * pos)
208    }
209
210    #[inline]
211    fn offset_for_coords(&self, coords: &[usize]) -> isize {
212        self.strides().iter().zip(coords.as_ref()).map(|(s, c)| *s * *c as isize).sum::<isize>()
213    }
214
215    #[inline]
216    pub unsafe fn at_unchecked<T: Datum>(&self, coords: impl AsRef<[usize]>) -> &T {
217        self.as_ptr_unchecked::<T>()
218            .offset(self.offset_for_coords(coords.as_ref()))
219            .as_ref()
220            .unwrap()
221    }
222
223    #[inline]
224    pub unsafe fn at_mut_unchecked<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> &mut T {
225        self.as_ptr_mut_unchecked::<T>()
226            .offset(self.offset_for_coords(coords.as_ref()))
227            .as_mut()
228            .unwrap()
229    }
230
231    #[inline]
232    pub fn at<T: Datum>(&self, coords: impl AsRef<[usize]>) -> TractResult<&T> {
233        self.check_dt::<T>()?;
234        let coords = coords.as_ref();
235        self.check_coords(coords)?;
236        unsafe { Ok(self.at_unchecked(coords)) }
237    }
238
239    #[inline]
240    pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
241        self.check_dt::<T>()?;
242        let coords = coords.as_ref();
243        self.check_coords(coords)?;
244        unsafe { Ok(self.at_mut_unchecked(coords)) }
245    }
246
247    /*
248      pub unsafe fn reshaped(&self, shape: impl AsRef<[usize]>) -> TensorView<'a> {
249      let shape = shape.as_ref();
250      let mut strides: TVec<isize> = shape
251      .iter()
252      .rev()
253      .scan(1, |state, d| {
254      let old = *state;
255    *state = *state * d;
256    Some(old as isize)
257    })
258    .collect();
259    strides.reverse();
260    TensorView { shape: shape.into(), strides, ..*self }
261    }
262    */
263}
264
#[cfg(test)]
mod test {
    use super::TensorView;
    use crate::prelude::Tensor;

    /// A one-entry prefix on a 2x2 tensor selects the second row.
    #[test]
    fn test_at_prefix() {
        let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
        let a_view = TensorView::at_prefix(&a, &[1]).unwrap();
        assert_eq!(a_view.shape(), &[2]);
        assert_eq!(a_view.as_slice::<i32>().unwrap(), &[3, 4]);
    }

    /// Out-of-bounds entries and prefixes longer than the rank are rejected.
    #[test]
    fn test_at_prefix_rejects_invalid_prefix() {
        let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
        assert!(TensorView::at_prefix(&a, &[2]).is_err());
        assert!(TensorView::at_prefix(&a, &[0, 0, 0]).is_err());
    }

    /// A full-rank prefix yields a rank-0 view of a single element.
    #[test]
    fn test_at_prefix_full_rank_is_scalar() {
        let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
        let v = TensorView::at_prefix(&a, &[1, 0]).unwrap();
        assert_eq!(v.rank(), 0);
        assert_eq!(v.len(), 1);
        assert_eq!(v.as_slice::<i32>().unwrap(), &[3]);
    }

    /// `at` returns the addressed element and validates coordinates and datum type.
    #[test]
    fn test_at() {
        let a = Tensor::from_shape(&[2, 3], &[1, 2, 3, 4, 5, 6]).unwrap();
        let v = unsafe { TensorView::view(&a) };
        assert_eq!(*v.at::<i32>([1, 2]).unwrap(), 6);
        assert!(v.at::<i32>([2, 0]).is_err());
        assert!(v.at::<i32>([0]).is_err());
        assert!(v.at::<f32>([0, 0]).is_err());
    }

    /// `offsetting` moves the origin but keeps the full tensor shape.
    #[test]
    fn test_offsetting() {
        let a = Tensor::from_shape(&[2, 3], &[1, 2, 3, 4, 5, 6]).unwrap();
        let v = TensorView::offsetting(&a, &[1, 1]).unwrap();
        assert_eq!(v.shape(), &[2, 3]);
        assert_eq!(unsafe { *v.at_unchecked::<i32>([0, 0]) }, 5);
        assert_eq!(unsafe { *v.at_unchecked::<i32>([0, 1]) }, 6);
        assert!(TensorView::offsetting(&a, &[1]).is_err());
        assert!(TensorView::offsetting(&a, &[2, 0]).is_err());
    }
}