// tract_hir/ops/array/strided_slice.rs

1use crate::internal::*;
2use tract_core::ops::array::StridedSlice;
3use tract_itertools::Itertools;
4
5impl InferenceRulesOp for StridedSlice {
6    fn rules<'r, 'p: 'r, 's: 'r>(
7        &'s self,
8        s: &mut Solver<'r>,
9        inputs: &'p [TensorProxy],
10        outputs: &'p [TensorProxy],
11    ) -> InferenceResult {
12        check_input_arity(
13            inputs,
14            3 + self.optional_axes_input.is_some() as usize
15                + self.optional_steps_input.is_some() as usize,
16        )?;
17        check_output_arity(outputs, 1)?;
18        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
19        s.equals(&inputs[1].rank, 1)?;
20        s.equals(&inputs[2].rank, 1)?;
21        s.equals(&inputs[1].shape[0], &inputs[2].shape[0])?;
22        s.equals(
23            &outputs[0].rank,
24            inputs[0].rank.bex() - self.shrink_axis_mask.count_ones() as i64,
25        )?;
26        if let Some(axis) = self.optional_axes_input {
27            s.equals(&inputs[1].shape, &inputs[axis].shape)?;
28        };
29        if let Some(step) = self.optional_steps_input {
30            s.equals(&inputs[1].shape, &inputs[step].shape)?;
31        };
32        if let Some(axes_input) = self.optional_axes_input {
33            s.given(&inputs[axes_input].value, move |s, axes| {
34                let axes = axes.cast_to::<i64>()?.into_owned();
35                s.given(&outputs[0].rank, move |s, orank| {
36                    let axes = axes
37                        .as_slice::<i64>()?
38                        .iter()
39                        .map(|a| if *a >= 0 { *a } else { *a + orank } as usize)
40                        .collect_vec();
41                    let mut iaxis = 0;
42                    for oaxis in 0..orank as usize {
43                        while self.shrink_axis_mask & (1 << iaxis) != 0 {
44                            iaxis += 1;
45                        }
46                        if !axes.contains(&iaxis) {
47                            s.equals(&inputs[0].shape[iaxis], &outputs[0].shape[oaxis])?;
48                        }
49                        iaxis += 1;
50                    }
51                    Ok(())
52                })
53            })?;
54        }
55        s.given(&inputs[0].shape, move |s, input_shape| {
56            s.given_all(inputs[1..].iter().map(|i| &i.value), move |s, params| {
57                let begin = &params[0];
58                let end = &params[1];
59                let strides = if let Some(i) = self.optional_steps_input {
60                    let t = params[i - 1].cast_to::<i32>()?;
61                    t.as_slice::<i32>()?.to_vec()
62                } else {
63                    vec![1; input_shape.len()]
64                };
65                let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
66                    let axes = params[i - 1].cast_to::<i32>()?;
67                    axes.as_slice::<i32>()?
68                        .iter()
69                        .map(|&i| if i < 0 { input_shape.len() as i32 + i } else { i } as usize)
70                        .collect()
71                } else {
72                    (0..input_shape.len()).collect()
73                };
74                let mut output_shape = input_shape.clone();
75                let mut shrink = vec![];
76                for (ix, axis) in axes.into_iter().enumerate() {
77                    let preped =
78                        self.prepare_one_dim(ix, &input_shape[axis], begin, end, &strides)?;
79                    output_shape[axis] = preped.soft_len()?;
80                    if preped.shrink {
81                        shrink.push(axis);
82                    }
83                }
84                for shrink in shrink.iter().sorted().rev() {
85                    output_shape.remove(*shrink);
86                }
87                s.equals(&outputs[0].shape, output_shape)
88            })
89        })
90    }
91
92    to_typed!();
93    as_op!();
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use tract_core::ops::array::strided_slice::Dim;
    use tract_ndarray::{arr1, arr2, arr3};

    /// Builds a `StridedSlice` with the given masks and no axes input. The steps are read from
    /// input #3, so the op expects inputs `[data, begin, end, strides]`.
    pub fn strided_slice(begin_mask: i64, end_mask: i64, shrink_axis_mask: i64) -> StridedSlice {
        StridedSlice {
            begin_mask,
            end_mask,
            shrink_axis_mask,
            optional_axes_input: None,
            optional_steps_input: Some(3),
        }
    }

    /// Evaluates `op` on `[input, begin, end, strides]` and returns its single output.
    ///
    /// Panics if evaluation fails or produces no output.
    fn eval<I, B, E, S>(op: StridedSlice, input: I, begin: B, end: E, strides: S) -> Tensor
    where
        I: Into<Tensor>,
        B: Into<Tensor>,
        E: Into<Tensor>,
        S: Into<Tensor>,
    {
        op.eval(tvec![
            input.into().into(),
            begin.into().into(),
            end.into().into(),
            strides.into().into(),
        ])
        .unwrap()
        .pop()
        .unwrap()
        .into_tensor()
    }

    // https://www.tensorflow.org/api_docs/python/tf/strided_slice
    #[test]
    fn eval_1() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
                tensor1(&[1, 0, 0]),
                tensor1(&[2, 1, 3]),
                tensor1(&[1, 1, 1])
            ),
            Tensor::from(arr3(&[[[3, 3, 3]]])),
        );
    }

    #[test]
    fn eval_2() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
                tensor1(&[1, 0, 0]),
                tensor1(&[2, 2, 3]),
                tensor1(&[1, 1, 1])
            ),
            Tensor::from(arr3(&[[[3, 3, 3], [4, 4, 4]]])),
        );
    }

    #[test]
    fn eval_3_negative_stride() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                arr3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
                tensor1(&[1, -1, 0]),
                tensor1(&[2, -3, 3]),
                tensor1(&[1, -1, 1])
            ),
            Tensor::from(arr3(&[[[4, 4, 4], [3, 3, 3]]])),
        );
    }

    #[test]
    fn eval_3_bis() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                arr1(&[0, 1]),
                tensor1(&[-1]),
                tensor1(&[-3]),
                tensor1(&[-1])
            ),
            Tensor::from(arr1(&[1, 0]))
        );
    }

    #[test]
    fn eval_4() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                tensor3(&[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]],]),
                tensor1(&[1, 0, 0]),
                tensor1(&[2, 2, 4]),
                tensor1(&[1, 1, 2])
            ),
            tensor3(&[[[3, 3], [4, 4]]]),
        );
    }

    #[test]
    fn eval_5() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                tensor1(&[0, 0]),
                tensor1(&[0]),
                tensor1(&[-1]),
                tensor1(&[1])
            ),
            tensor1(&[0])
        )
    }

    #[test]
    fn eval_6() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                tensor2(&[[1, 0, 0, 0], [3, 0, 0, 0], [0, 0, 0, 0]]),
                tensor1(&[-3, -4]),
                tensor1(&[-1, -1]),
                tensor1(&[1, 2])
            ),
            tensor2(&[[1, 0], [3, 0]])
        )
    }

    /// A begin/end spec shorter than the input rank leaves the trailing axes untouched.
    #[test]
    fn eval_7() {
        assert_eq!(
            eval(
                strided_slice(0, 0, 0),
                tensor2(&[[0, 6], [0, 0]]),
                tensor1(&[0]),
                tensor1(&[2]),
                tensor1(&[1])
            ),
            tensor2(&[[0, 6], [0, 0]])
        )
    }

    #[test]
    fn eval_begin_mask_1() {
        let mut op = strided_slice(0, 0, 0);
        op.begin_mask = 1;
        assert_eq!(
            eval(op, tensor1(&[0, 1]), tensor1(&[1]), tensor1(&[1]), tensor1(&[1])),
            tensor1(&[0])
        )
    }

    #[test]
    fn eval_shrink_1() {
        let mut op = strided_slice(0, 0, 0);
        op.shrink_axis_mask = 1;
        assert_eq!(
            eval(op, arr2(&[[0]]), tensor1(&[0, 0]), tensor1(&[0, 0]), tensor1(&[1, 1])),
            tensor1::<i32>(&[])
        )
    }

    #[test]
    fn eval_shrink_to_scalar() {
        let mut op = strided_slice(0, 0, 0);
        op.shrink_axis_mask = 1;
        assert_eq!(
            eval(op, tensor1(&[0]), tensor1(&[0]), tensor1(&[0]), tensor1(&[1])),
            tensor0::<i32>(0)
        )
    }

    #[test]
    fn inference_1() {
        let mut op = strided_slice(5, 7, 0);
        let input = InferenceFact::default().with_datum_type(DatumType::F32);
        let begin = InferenceFact::from(tensor1(&[0i32, 2, 0]));
        let end = InferenceFact::from(tensor1(&[0i32, 0, 0]));
        let strides = InferenceFact::from(tensor1(&[1i32, 1, 1]));
        let any = InferenceFact::default();

        let (input_facts, output_facts, _) =
            op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();
        assert_eq!(
            input_facts,
            tvec![
                InferenceFact::default()
                    .with_datum_type(DatumType::F32)
                    .with_shape(shapefactoid![..]),
                begin,
                end,
                strides,
            ]
        );
        assert_eq!(
            output_facts,
            tvec![InferenceFact::default()
                .with_datum_type(DatumType::F32)
                .with_shape(shapefactoid![..]),]
        );
    }

    #[test]
    fn inference_2() {
        let mut op = strided_slice(1, 1, 2);
        let input = InferenceFact::default().with_datum_type(DatumType::F32);
        let begin = InferenceFact::from(tensor1(&[0i32, 0]));
        let end = InferenceFact::from(tensor1(&[0i32, 1]));
        let strides = InferenceFact::from(tensor1(&[1i32, 1]));
        let any = InferenceFact::default();

        let (input_facts, output_facts, _) =
            op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();
        assert_eq!(
            input_facts,
            tvec![
                InferenceFact::default()
                    .with_datum_type(DatumType::F32)
                    .with_shape(shapefactoid![..]),
                begin,
                end,
                strides,
            ]
        );
        assert_eq!(
            output_facts,
            tvec![InferenceFact::default()
                .with_datum_type(DatumType::F32)
                .with_shape(shapefactoid![..]),]
        );
    }

    /// Symbolic input dimensions propagate through the slice (`S - 2` sliced from 2 gives `S - 4`).
    #[test]
    fn inference_3() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S").to_dim();
        let mut op = strided_slice(5, 7, 0);
        let input = f32::fact(dims!(1, s.clone() - 2, 16)).into();
        let begin = InferenceFact::from(tensor1(&[0i32, 2, 0]));
        let end = InferenceFact::from(tensor1(&[0i32, 0, 0]));
        let strides = InferenceFact::from(tensor1(&[1i32, 1, 1]));
        let any = InferenceFact::default();

        let (_, output_facts, _) =
            op.infer_facts(tvec![&input, &begin, &end, &strides], tvec![&any], tvec!()).unwrap();

        assert_eq!(output_facts, tvec![f32::fact(dims!(1, s - 4, 16)).into()]);
    }

    #[test]
    fn prep_1() {
        let op = strided_slice(0, 0, 0);
        assert_eq!(
            op.prepare_one_dim(
                0,
                &4.to_dim(),
                &tensor1(&[-1i64]),
                &tensor1(&[i64::MIN]),
                &[-1]
            )
            .unwrap(),
            Dim { begin: 3.to_dim(), end: (-1).to_dim(), stride: -1, shrink: false }
        );
    }

    /// An `end` of `i64::MIN + 1` must be treated like `i64::MIN` (same result as `prep_1`).
    #[test]
    fn prep_pytorch_onnx_bug_workaround() {
        let op = strided_slice(0, 0, 0);
        assert_eq!(
            op.prepare_one_dim(
                0,
                &4.to_dim(),
                &tensor1(&[-1i64]),
                &tensor1(&[i64::MIN + 1]),
                &[-1]
            )
            .unwrap(),
            Dim { begin: 3.to_dim(), end: (-1).to_dim(), stride: -1, shrink: false }
        );
    }
}