tract_data/tensor/
view.rs1use super::*;
2use crate::internal::*;
3
/// Describes where a [`TensorView`] takes its shape and strides from.
#[derive(Debug, Clone)]
enum Indexing<'a> {
    /// Skips the first `n` axes of the viewed tensor: the view uses the tensor's own
    /// shape and strides from axis `n` onward.
    Prefix(usize),
    /// Uses an explicitly supplied shape and strides.
    Custom {
        /// Extent of every axis of the view.
        shape: &'a [usize],
        /// Step between consecutive positions of every axis, counted in elements.
        strides: &'a [isize],
    },
}
9
10#[derive(Clone, Debug)]
11pub struct TensorView<'a> {
12 pub tensor: &'a Tensor,
13 offset_bytes: isize,
14 indexing: Indexing<'a>,
15}
16
17impl<'a> TensorView<'a> {
18 pub unsafe fn from_bytes(
19 tensor: &'a Tensor,
20 offset_bytes: isize,
21 shape: &'a [usize],
22 strides: &'a [isize],
23 ) -> TensorView<'a> {
24 TensorView {
25 tensor,
26 offset_bytes,
27 indexing: Indexing::Custom { shape, strides },
28 }
29 }
30
31 pub fn offsetting(tensor: &'a Tensor, coords: &[usize]) -> TractResult<TensorView<'a>> {
32 ensure!(
33 coords.len() == tensor.rank() && coords.iter().zip(tensor.shape()).all(|(p, d)| p < d),
34 "Invalid coords {:?} for shape {:?}",
35 coords,
36 tensor.shape()
37 );
38 unsafe { Ok(Self::offsetting_unchecked(tensor, coords)) }
39 }
40
41 pub unsafe fn offsetting_unchecked(tensor: &'a Tensor, coords: &[usize]) -> TensorView<'a> {
42 let offset_bytes =
43 coords.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
44 * tensor.datum_type().size_of() as isize;
45 TensorView {
46 tensor,
47 offset_bytes,
48 indexing: Indexing::Custom {
49 shape: &tensor.shape,
50 strides: &tensor.strides,
51 },
52 }
53 }
54
55 pub fn at_prefix(tensor: &'a Tensor, prefix: &[usize]) -> TractResult<TensorView<'a>> {
56 ensure!(
57 prefix.len() <= tensor.rank() && prefix.iter().zip(tensor.shape()).all(|(p, d)| p < d),
58 "Invalid prefix {:?} for shape {:?}",
59 prefix,
60 tensor.shape()
61 );
62 unsafe { Ok(Self::at_prefix_unchecked(tensor, prefix)) }
63 }
64
65 pub unsafe fn at_prefix_unchecked(tensor: &'a Tensor, prefix: &[usize]) -> TensorView<'a> {
66 let offset_bytes =
67 prefix.iter().zip(tensor.strides()).map(|(a, b)| *a as isize * b).sum::<isize>()
68 * tensor.datum_type().size_of() as isize;
69 TensorView { tensor, offset_bytes, indexing: Indexing::Prefix(prefix.len()) }
70 }
71
72 #[inline]
73 pub unsafe fn view(tensor: &'a Tensor) -> TensorView<'a> {
74 TensorView { tensor, offset_bytes: 0, indexing: Indexing::Prefix(0) }
75 }
76
77 #[inline]
78 pub fn datum_type(&self) -> DatumType {
79 self.tensor.datum_type()
80 }
81
82 #[inline]
83 pub fn shape(&self) -> &[usize] {
84 match &self.indexing {
85 Indexing::Prefix(i) => &self.tensor.shape()[*i..],
86 Indexing::Custom { shape, .. } => shape,
87 }
88 }
89
90 #[inline]
91 pub fn strides(&self) -> &[isize] {
92 match &self.indexing {
93 Indexing::Prefix(i) => &self.tensor.strides()[*i..],
94 Indexing::Custom { strides, .. } => strides,
95 }
96 }
97
98 #[inline]
99 #[allow(clippy::len_without_is_empty)]
100 pub fn len(&self) -> usize {
101 match &self.indexing {
102 Indexing::Prefix(i) => {
103 if *i == 0 {
104 self.tensor.len()
105 } else {
106 self.tensor.strides[*i - 1] as usize
107 }
108 }
109 Indexing::Custom { shape, .. } => shape.iter().product(),
110 }
111 }
112
113 #[inline]
114 #[allow(clippy::len_without_is_empty)]
115 pub fn valid_bytes(&self) -> usize {
116 self.tensor.data.layout().size() - self.offset_bytes as usize
117 }
118
119 #[inline]
120 pub fn rank(&self) -> usize {
121 match &self.indexing {
122 Indexing::Prefix(i) => self.tensor.rank() - i,
123 Indexing::Custom { shape, .. } => shape.len(),
124 }
125 }
126
127 fn check_dt<D: Datum>(&self) -> TractResult<()> {
128 self.tensor.check_for_access::<D>()
129 }
130
131 fn check_coords(&self, coords: &[usize]) -> TractResult<()> {
132 ensure!(
133 coords.len() == self.rank()
134 && coords.iter().zip(self.shape()).all(|(&x, &dim)| x < dim),
135 "Can't access coordinates {:?} of TensorView of shape {:?}",
136 coords,
137 self.shape(),
138 );
139 Ok(())
140 }
141
142 #[inline]
144 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
145 self.check_dt::<D>()?;
146 Ok(unsafe { self.as_ptr_unchecked() })
147 }
148
149 #[inline]
151 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
152 self.tensor.as_ptr_unchecked::<u8>().offset(self.offset_bytes) as *const D
153 }
154
155 #[inline]
157 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
158 self.as_ptr_unchecked::<D>() as *mut D
159 }
160
161 #[inline]
163 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
164 Ok(self.as_ptr::<D>()? as *mut D)
165 }
166
167 #[inline]
169 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
170 std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
171 }
172
173 #[inline]
175 pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
176 self.check_dt::<D>()?;
177 unsafe { Ok(self.as_slice_unchecked()) }
178 }
179
180 #[inline]
182 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
183 std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
184 }
185
186 #[inline]
188 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
189 self.check_dt::<D>()?;
190 unsafe { Ok(self.as_slice_mut_unchecked()) }
191 }
192
193 #[inline]
194 pub unsafe fn offset_bytes(&mut self, offset: isize) {
195 self.offset_bytes += offset
196 }
197
198 #[inline]
199 pub unsafe fn offset_axis_unchecked(&mut self, axis: usize, pos: isize) {
200 let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
201 self.offset_bytes(stride * pos)
202 }
203
204 #[inline]
205 pub unsafe fn offset_axis(&mut self, axis: usize, pos: isize) {
206 let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
207 self.offset_bytes(stride * pos)
208 }
209
210 #[inline]
211 fn offset_for_coords(&self, coords: &[usize]) -> isize {
212 self.strides().iter().zip(coords.as_ref()).map(|(s, c)| *s * *c as isize).sum::<isize>()
213 }
214
215 #[inline]
216 pub unsafe fn at_unchecked<T: Datum>(&self, coords: impl AsRef<[usize]>) -> &T {
217 self.as_ptr_unchecked::<T>()
218 .offset(self.offset_for_coords(coords.as_ref()))
219 .as_ref()
220 .unwrap()
221 }
222
223 #[inline]
224 pub unsafe fn at_mut_unchecked<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> &mut T {
225 self.as_ptr_mut_unchecked::<T>()
226 .offset(self.offset_for_coords(coords.as_ref()))
227 .as_mut()
228 .unwrap()
229 }
230
231 #[inline]
232 pub fn at<T: Datum>(&self, coords: impl AsRef<[usize]>) -> TractResult<&T> {
233 self.check_dt::<T>()?;
234 let coords = coords.as_ref();
235 self.check_coords(coords)?;
236 unsafe { Ok(self.at_unchecked(coords)) }
237 }
238
239 #[inline]
240 pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
241 self.check_dt::<T>()?;
242 let coords = coords.as_ref();
243 self.check_coords(coords)?;
244 unsafe { Ok(self.at_mut_unchecked(coords)) }
245 }
246
247 }
264
#[cfg(test)]
mod test {
    use super::TensorView;
    use crate::prelude::Tensor;

    /// 2x2 `i32` tensor `[[1, 2], [3, 4]]` shared by the tests below.
    fn sample() -> Tensor {
        Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap()
    }

    #[test]
    fn test_at_prefix() {
        let a = sample();
        let a_view = TensorView::at_prefix(&a, &[1]).unwrap();
        assert_eq!(a_view.shape(), &[2]);
        assert_eq!(a_view.as_slice::<i32>().unwrap(), &[3, 4]);
    }

    #[test]
    fn test_at_prefix_empty_is_whole_tensor() {
        let a = sample();
        let view = TensorView::at_prefix(&a, &[]).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.rank(), 2);
        assert_eq!(view.len(), 4);
        assert_eq!(view.as_slice::<i32>().unwrap(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_at_prefix_full_rank_is_scalar() {
        let a = sample();
        let view = TensorView::at_prefix(&a, &[1, 0]).unwrap();
        assert_eq!(view.rank(), 0);
        assert_eq!(view.len(), 1);
        assert_eq!(view.as_slice::<i32>().unwrap(), &[3]);
    }

    #[test]
    fn test_at_prefix_rejects_invalid_prefix() {
        let a = sample();
        assert!(TensorView::at_prefix(&a, &[2]).is_err());
        assert!(TensorView::at_prefix(&a, &[0, 0, 0]).is_err());
    }

    #[test]
    fn test_as_slice_rejects_wrong_datum_type() {
        let a = sample();
        let view = TensorView::at_prefix(&a, &[0]).unwrap();
        assert!(view.as_slice::<f32>().is_err());
    }

    #[test]
    fn test_at() {
        let a = sample();
        let view = TensorView::at_prefix(&a, &[1]).unwrap();
        let value: &i32 = view.at([1usize]).unwrap();
        assert_eq!(*value, 4);
        let out_of_range: Result<&i32, _> = view.at([2usize]);
        assert!(out_of_range.is_err());
    }

    #[test]
    fn test_offsetting_validates_coords() {
        let a = sample();
        assert!(TensorView::offsetting(&a, &[0]).is_err());
        assert!(TensorView::offsetting(&a, &[0, 2]).is_err());
        let origin = TensorView::offsetting(&a, &[0, 0]).unwrap();
        assert_eq!(origin.shape(), &[2, 2]);
        assert_eq!(origin.as_slice::<i32>().unwrap(), &[1, 2, 3, 4]);
    }
}