1use crate::internal::*;
2use tract_core::ops::array::StridedSlice;
3use tract_itertools::Itertools;
4
5impl InferenceRulesOp for StridedSlice {
6 fn rules<'r, 'p: 'r, 's: 'r>(
7 &'s self,
8 s: &mut Solver<'r>,
9 inputs: &'p [TensorProxy],
10 outputs: &'p [TensorProxy],
11 ) -> InferenceResult {
12 check_input_arity(
13 inputs,
14 3 + self.optional_axes_input.is_some() as usize
15 + self.optional_steps_input.is_some() as usize,
16 )?;
17 check_output_arity(outputs, 1)?;
18 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
19 s.equals(&inputs[1].rank, 1)?;
20 s.equals(&inputs[2].rank, 1)?;
21 s.equals(&inputs[1].shape[0], &inputs[2].shape[0])?;
22 s.equals(
23 &outputs[0].rank,
24 inputs[0].rank.bex() - self.shrink_axis_mask.count_ones() as i64,
25 )?;
26 if let Some(axis) = self.optional_axes_input {
27 s.equals(&inputs[1].shape, &inputs[axis].shape)?;
28 };
29 if let Some(step) = self.optional_steps_input {
30 s.equals(&inputs[1].shape, &inputs[step].shape)?;
31 };
32 if let Some(axes_input) = self.optional_axes_input {
33 s.given(&inputs[axes_input].value, move |s, axes| {
34 let axes = axes.cast_to::<i64>()?.into_owned();
35 s.given(&outputs[0].rank, move |s, orank| {
36 let axes = axes
37 .as_slice::<i64>()?
38 .iter()
39 .map(|a| if *a >= 0 { *a } else { *a + orank } as usize)
40 .collect_vec();
41 let mut iaxis = 0;
42 for oaxis in 0..orank as usize {
43 while self.shrink_axis_mask & (1 << iaxis) != 0 {
44 iaxis += 1;
45 }
46 if !axes.contains(&iaxis) {
47 s.equals(&inputs[0].shape[iaxis], &outputs[0].shape[oaxis])?;
48 }
49 iaxis += 1;
50 }
51 Ok(())
52 })
53 })?;
54 }
55 s.given(&inputs[0].shape, move |s, input_shape| {
56 s.given_all(inputs[1..].iter().map(|i| &i.value), move |s, params| {
57 let begin = ¶ms[0];
58 let end = ¶ms[1];
59 let strides = if let Some(i) = self.optional_steps_input {
60 let t = params[i - 1].cast_to::<i32>()?;
61 t.as_slice::<i32>()?.to_vec()
62 } else {
63 vec![1; input_shape.len()]
64 };
65 let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
66 let axes = params[i - 1].cast_to::<i32>()?;
67 axes.as_slice::<i32>()?
68 .iter()
69 .map(|&i| if i < 0 { input_shape.len() as i32 + i } else { i } as usize)
70 .collect()
71 } else {
72 (0..input_shape.len()).collect()
73 };
74 let mut output_shape = input_shape.clone();
75 let mut shrink = vec![];
76 for (ix, axis) in axes.into_iter().enumerate() {
77 let preped =
78 self.prepare_one_dim(ix, &input_shape[axis], begin, end, &strides)?;
79 output_shape[axis] = preped.soft_len()?;
80 if preped.shrink {
81 shrink.push(axis);
82 }
83 }
84 for shrink in shrink.iter().sorted().rev() {
85 output_shape.remove(*shrink);
86 }
87 s.equals(&outputs[0].shape, output_shape)
88 })
89 })
90 }
91
92 to_typed!();
93 as_op!();
94}
95
96#[cfg(test)]
97mod tests {
98 #![allow(non_snake_case)]
99 use super::*;
100 use tract_core::ops::array::strided_slice::Dim;
101 use tract_ndarray::{arr1, arr2, arr3};
102
103 pub fn strided_slice(begin_mask: i64, end_mask: i64, shrink_axis_mask: i64) -> StridedSlice {
104 StridedSlice {
105 begin_mask,
106 end_mask,
107 shrink_axis_mask,
108 optional_axes_input: None,
109 optional_steps_input: Some(3),
110 }
111 }
112
113 fn eval<I, B, E, S>(op: StridedSlice, input: I, begin: B, end: E, strides: S) -> Tensor
114 where
115 I: Into<Tensor>,
116 B: Into<Tensor>,
117 E: Into<Tensor>,
118 S: Into<Tensor>,
119 {
120 op.eval(tvec![
121 input.into().into(),
122 begin.into().into(),
123 end.into().into(),
124 strides.into().into(),
125 ])
126 .unwrap()
127 .pop()
128 .unwrap()
129 .into_tensor()
130 }
131
132 #[test]
134 fn eval_1() {
135 assert_eq!(
136 eval(
137 strided_slice(0, 0, 0),
138 arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
139 tensor1(&[1, 0, 0]),
140 tensor1(&[2, 1, 3]),
141 tensor1(&[1, 1, 1])
142 ),
143 Tensor::from(arr3(&[[[3, 3, 3]]])),
144 );
145 }
146
147 #[test]
148 fn eval_2() {
149 assert_eq!(
150 eval(
151 strided_slice(0, 0, 0),
152 arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
153 tensor1(&[1, 0, 0]),
154 tensor1(&[2, 2, 3]),
155 tensor1(&[1, 1, 1])
156 ),
157 Tensor::from(arr3(&[[[3, 3, 3], [4, 4, 4]]])),
158 );
159 }
160
161 #[test]
162 fn eval_3_negative_stride() {
163 assert_eq!(
164 eval(
165 strided_slice(0, 0, 0),
166 arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
167 tensor1(&[1, -1, 0]),
168 tensor1(&[2, -3, 3]),
169 tensor1(&[1, -1, 1])
170 ),
171 Tensor::from(arr3(&[[[4, 4, 4], [3, 3, 3]]])),
172 );
173 }
174
175 #[test]
176 fn eval_3_bis() {
177 assert_eq!(
178 eval(
179 strided_slice(0, 0, 0),
180 arr1(&[0, 1]),
181 tensor1(&[-1]),
182 tensor1(&[-3]),
183 tensor1(&[-1])
184 ),
185 Tensor::from(arr1(&[1, 0]))
186 );
187 }
188
189 #[test]
190 fn eval_4() {
191 assert_eq!(
192 eval(
193 strided_slice(0, 0, 0),
194 tensor3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
195 tensor1(&[1, 0, 0]),
196 tensor1(&[2, 2, 4]),
197 tensor1(&[1, 1, 2])
198 ),
199 tensor3(&[[[3, 3], [4, 4]]]),
200 );
201 }
202
203 #[test]
204 fn eval_5() {
205 assert_eq!(
206 eval(
207 strided_slice(0, 0, 0),
208 tensor1(&[0, 0]),
209 tensor1(&[0]),
210 tensor1(&[-1]),
211 tensor1(&[1])
212 ),
213 tensor1(&[0])
214 )
215 }
216
217 #[test]
218 fn eval_6() {
219 assert_eq!(
220 eval(
221 strided_slice(0, 0, 0),
222 tensor2(&[[1, 0, 0, 0], [3, 0, 0, 0], [0, 0, 0, 0]]),
223 tensor1(&[-3, -4]),
224 tensor1(&[-1, -1]),
225 tensor1(&[1, 2])
226 ),
227 tensor2(&[[1, 0], [3, 0]])
228 )
229 }
230
231 #[test]
232 fn eval_7() {
233 assert_eq!(
234 eval(
235 strided_slice(0, 0, 0),
236 tensor2(&[[0, 6], [0, 0]]),
237 tensor1(&[0]),
238 tensor1(&[2]),
239 tensor1(&[1])
240 ),
241 tensor2(&[[0, 6], [0, 0]])
242 )
243 }
244
245 #[test]
246 fn eval_begin_mask_1() {
247 let mut op = strided_slice(0, 0, 0);
248 op.begin_mask = 1;
249 assert_eq!(
250 eval(op, tensor1(&[0, 1]), tensor1(&[1]), tensor1(&[1]), tensor1(&[1])),
251 tensor1(&[0])
252 )
253 }
254
255 #[test]
256 fn eval_shrink_1() {
257 let mut op = strided_slice(0, 0, 0);
258 op.shrink_axis_mask = 1;
259 assert_eq!(
260 eval(op, arr2(&[[0]]), tensor1(&[0, 0]), tensor1(&[0, 0]), tensor1(&[1, 1])),
261 tensor1::<i32>(&[])
262 )
263 }
264
265 #[test]
266 fn eval_shrink_to_scalar() {
267 let mut op = strided_slice(0, 0, 0);
268 op.shrink_axis_mask = 1;
269 assert_eq!(
270 eval(op, tensor1(&[0]), tensor1(&[0]), tensor1(&[0]), tensor1(&[1])),
271 tensor0::<i32>(0)
272 )
273 }
274
275 #[test]
276 fn inference_1() {
277 let mut op = strided_slice(5, 7, 0);
278 let input = InferenceFact::default().with_datum_type(DatumType::F32);
279 let begin = InferenceFact::from(tensor1(&[0i32, 2, 0]));
280 let end = InferenceFact::from(tensor1(&[0i32, 0, 0]));
281 let strides = InferenceFact::from(tensor1(&[1i32, 1, 1]));
282 let any = InferenceFact::default();
283
284 let (input_facts, output_facts, _) =
285 op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();
286 assert_eq!(
287 input_facts,
288 tvec![
289 InferenceFact::default()
290 .with_datum_type(DatumType::F32)
291 .with_shape(shapefactoid![..]),
292 begin,
293 end,
294 strides,
295 ]
296 );
297 assert_eq!(
298 output_facts,
299 tvec![InferenceFact::default()
300 .with_datum_type(DatumType::F32)
301 .with_shape(shapefactoid![..]),]
302 );
303 }
304
305 #[test]
306 fn inference_2() {
307 let mut op = strided_slice(1, 1, 2);
308 let input = InferenceFact::default().with_datum_type(DatumType::F32);
309 let begin = InferenceFact::from(tensor1(&[0i32, 0]));
310 let end = InferenceFact::from(tensor1(&[0i32, 1]));
311 let strides = InferenceFact::from(tensor1(&[1i32, 1]));
312 let any = InferenceFact::default();
313
314 let (input_facts, output_facts, _) =
315 op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();
316 assert_eq!(
317 input_facts,
318 tvec![
319 InferenceFact::default()
320 .with_datum_type(DatumType::F32)
321 .with_shape(shapefactoid![..]),
322 begin,
323 end,
324 strides,
325 ]
326 );
327 assert_eq!(
328 output_facts,
329 tvec![InferenceFact::default()
330 .with_datum_type(DatumType::F32)
331 .with_shape(shapefactoid![..]),]
332 );
333 }
334
335 #[test]
336 fn inference_3() {
337 let table = SymbolScope::default();
338 let s = table.new_with_prefix("S").to_dim();
339 let mut op = strided_slice(5, 7, 0);
340 let input = f32::fact(dims!(1, s.clone() - 2, 16)).into();
341 let begin = InferenceFact::from(tensor1(&[0i32, 2, 0]));
342 let end = InferenceFact::from(tensor1(&[0i32, 0, 0]));
343 let strides = InferenceFact::from(tensor1(&[1i32, 1, 1]));
344 let any = InferenceFact::default();
345
346 let (_, output_facts, _) =
347 op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();
348
349 assert_eq!(output_facts, tvec![f32::fact(dims!(1, s - 4, 16)).into()]);
350 }
351
352 #[test]
353 fn prep_1() {
354 let op = strided_slice(0, 0, 0);
355 assert_eq!(
356 op.prepare_one_dim(
357 0,
358 &4.to_dim(),
359 &tensor1(&[-1i64]),
360 &tensor1(&[i64::MIN]),
361 &[-1]
362 )
363 .unwrap(),
364 Dim { begin: 3.to_dim(), end: (-1).to_dim(), stride: -1, shrink: false }
365 );
366 }
367
368 #[test]
369 fn prep_pytorch_onnx_bug_workadound() {
370 let op = strided_slice(0, 0, 0);
371 assert_eq!(
372 op.prepare_one_dim(
373 0,
374 &4.to_dim(),
375 &tensor1(&[-1i64]),
376 &tensor1(&[i64::MIN + 1]),
377 &[-1]
378 )
379 .unwrap(),
380 Dim { begin: 3.to_dim(), end: (-1).to_dim(), stride: -1, shrink: false }
381 );
382 }
383}