tract_data/tensor/
view.rs1use super::*;
2use crate::internal::*;
3
/// Describes where a `TensorView` takes its shape and strides from.
#[derive(Debug, Clone)]
enum Indexing<'a> {
    /// Reuse the viewed tensor's own shape and strides, skipping the first `n` axes.
    Prefix(usize),
    /// Use an explicitly borrowed shape and strides (strides counted in elements).
    Custom { shape: &'a [usize], strides: &'a [isize] },
}
9
10#[derive(Clone, Debug)]
11pub struct TensorView<'a> {
12 pub tensor: &'a Tensor,
13 offset_bytes: isize,
14 indexing: Indexing<'a>,
15}
16
17impl<'a> TensorView<'a> {
18 pub unsafe fn from_bytes(
19 tensor: &'a Tensor,
20 offset_bytes: isize,
21 shape: &'a [usize],
22 strides: &'a [isize],
23 ) -> TensorView<'a> {
24 TensorView { tensor, offset_bytes, indexing: Indexing::Custom { shape, strides } }
25 }
26
27 pub fn offsetting(tensor: &'a Tensor, coords: &[usize]) -> TractResult<TensorView<'a>> {
28 ensure!(
29 coords.len() == tensor.rank() && coords.iter().zip(tensor.shape()).all(|(p, d)| p < d),
30 "Invalid coords {:?} for shape {:?}",
31 coords,
32 tensor.shape()
33 );
34 unsafe { Ok(Self::offsetting_unchecked(tensor, coords)) }
35 }
36
37 pub unsafe fn offsetting_unchecked(tensor: &'a Tensor, coords: &[usize]) -> TensorView<'a> {
38 let offset_bytes =
39 coords.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
40 * tensor.datum_type().size_of() as isize;
41 TensorView {
42 tensor,
43 offset_bytes,
44 indexing: Indexing::Custom { shape: &tensor.shape, strides: &tensor.strides },
45 }
46 }
47
48 pub fn at_prefix(tensor: &'a Tensor, prefix: &[usize]) -> TractResult<TensorView<'a>> {
49 ensure!(
50 prefix.len() <= tensor.rank() && prefix.iter().zip(tensor.shape()).all(|(p, d)| p < d),
51 "Invalid prefix {:?} for shape {:?}",
52 prefix,
53 tensor.shape()
54 );
55 unsafe { Ok(Self::at_prefix_unchecked(tensor, prefix)) }
56 }
57
58 pub unsafe fn at_prefix_unchecked(tensor: &'a Tensor, prefix: &[usize]) -> TensorView<'a> {
59 let offset_bytes =
60 prefix.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
61 * tensor.datum_type().size_of() as isize;
62 TensorView { tensor, offset_bytes, indexing: Indexing::Prefix(prefix.len()) }
63 }
64
65 #[inline]
66 pub unsafe fn view(tensor: &'a Tensor) -> TensorView<'a> {
67 TensorView { tensor, offset_bytes: 0, indexing: Indexing::Prefix(0) }
68 }
69
70 #[inline]
71 pub fn datum_type(&self) -> DatumType {
72 self.tensor.datum_type()
73 }
74
75 #[inline]
76 pub fn shape(&self) -> &[usize] {
77 match &self.indexing {
78 Indexing::Prefix(i) => &self.tensor.shape()[*i..],
79 Indexing::Custom { shape, .. } => shape,
80 }
81 }
82
83 #[inline]
84 pub fn strides(&self) -> &[isize] {
85 match &self.indexing {
86 Indexing::Prefix(i) => &self.tensor.strides()[*i..],
87 Indexing::Custom { strides, .. } => strides,
88 }
89 }
90
91 #[inline]
92 #[allow(clippy::len_without_is_empty)]
93 pub fn len(&self) -> usize {
94 match &self.indexing {
95 Indexing::Prefix(i) => {
96 if *i == 0 {
97 self.tensor.len()
98 } else {
99 self.tensor.strides[*i - 1] as usize
100 }
101 }
102 Indexing::Custom { shape, .. } => shape.iter().product(),
103 }
104 }
105
106 #[inline]
107 #[allow(clippy::len_without_is_empty)]
108 pub fn valid_bytes(&self) -> usize {
109 self.tensor.data.layout().size() - self.offset_bytes as usize
110 }
111
112 #[inline]
113 pub fn rank(&self) -> usize {
114 match &self.indexing {
115 Indexing::Prefix(i) => self.tensor.rank() - i,
116 Indexing::Custom { shape, .. } => shape.len(),
117 }
118 }
119
120 fn check_dt<D: Datum>(&self) -> TractResult<()> {
121 self.tensor.check_for_access::<D>()
122 }
123
124 fn check_coords(&self, coords: &[usize]) -> TractResult<()> {
125 ensure!(
126 coords.len() == self.rank()
127 && coords.iter().zip(self.shape()).all(|(&x, &dim)| x < dim),
128 "Can't access coordinates {:?} of TensorView of shape {:?}",
129 coords,
130 self.shape(),
131 );
132 Ok(())
133 }
134
135 #[inline]
137 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
138 self.check_dt::<D>()?;
139 Ok(unsafe { self.as_ptr_unchecked() })
140 }
141
142 #[inline]
144 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
145 unsafe { self.tensor.as_ptr_unchecked::<u8>().offset(self.offset_bytes) as *const D }
146 }
147
148 #[inline]
150 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
151 unsafe { self.as_ptr_unchecked::<D>() as *mut D }
152 }
153
154 #[inline]
156 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
157 Ok(self.as_ptr::<D>()? as *mut D)
158 }
159
160 #[inline]
162 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
163 unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
164 }
165
166 #[inline]
168 pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
169 self.check_dt::<D>()?;
170 unsafe { Ok(self.as_slice_unchecked()) }
171 }
172
173 #[inline]
175 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
176 unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
177 }
178
179 #[inline]
181 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
182 self.check_dt::<D>()?;
183 unsafe { Ok(self.as_slice_mut_unchecked()) }
184 }
185
186 #[inline]
187 pub unsafe fn offset_bytes(&mut self, offset: isize) {
188 self.offset_bytes += offset
189 }
190
191 #[inline]
192 pub unsafe fn offset_axis_unchecked(&mut self, axis: usize, pos: isize) {
193 let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
194 unsafe { self.offset_bytes(stride * pos) }
195 }
196
197 #[inline]
198 pub unsafe fn offset_axis(&mut self, axis: usize, pos: isize) {
199 let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
200 unsafe { self.offset_bytes(stride * pos) }
201 }
202
203 #[inline]
204 fn offset_for_coords(&self, coords: &[usize]) -> isize {
205 self.strides().iter().zip(coords.as_ref()).map(|(s, c)| *s * *c as isize).sum::<isize>()
206 }
207
208 #[inline]
209 pub unsafe fn at_unchecked<T: Datum>(&self, coords: impl AsRef<[usize]>) -> &T {
210 unsafe {
211 self.as_ptr_unchecked::<T>()
212 .offset(self.offset_for_coords(coords.as_ref()))
213 .as_ref()
214 .unwrap()
215 }
216 }
217
218 #[inline]
219 pub unsafe fn at_mut_unchecked<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> &mut T {
220 unsafe {
221 self.as_ptr_mut_unchecked::<T>()
222 .offset(self.offset_for_coords(coords.as_ref()))
223 .as_mut()
224 .unwrap()
225 }
226 }
227
228 #[inline]
229 pub fn at<T: Datum>(&self, coords: impl AsRef<[usize]>) -> TractResult<&T> {
230 self.check_dt::<T>()?;
231 let coords = coords.as_ref();
232 self.check_coords(coords)?;
233 unsafe { Ok(self.at_unchecked(coords)) }
234 }
235
236 #[inline]
237 pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
238 self.check_dt::<T>()?;
239 let coords = coords.as_ref();
240 self.check_coords(coords)?;
241 unsafe { Ok(self.at_mut_unchecked(coords)) }
242 }
243
244 }
261
#[cfg(test)]
mod test {
    use super::TensorView;
    use crate::prelude::Tensor;

    /// Fixing the first axis of a 2x2 tensor to 1 yields a view on its second row.
    #[test]
    fn test_at_prefix() {
        let matrix = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
        let second_row = TensorView::at_prefix(&matrix, &[1]).unwrap();
        assert_eq!(second_row.shape(), &[2]);
        let values: &[i32] = second_row.as_slice().unwrap();
        assert_eq!(values, &[3, 4]);
    }
}