Skip to main content

tract_core/ops/array/
strided_slice.rs

1use crate::internal::*;
2
/// Strided slicing operator: slices, optionally strides, and optionally drops axes of its
/// first input.
///
/// Input 0 is the data, input 1 the begin bounds and input 2 the end bounds. The masks are
/// bit sets indexed by the position of an entry in the begin/end specification.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct StridedSlice {
    /// Index, among the op inputs, of the tensor listing the axes to slice.
    /// When `None`, every axis of the input is visited in order.
    pub optional_axes_input: Option<usize>,
    /// Index, among the op inputs, of the tensor holding the strides.
    /// When `None`, a stride of 1 is used everywhere.
    pub optional_steps_input: Option<usize>,
    /// Bit `ix` set: the begin bound of spec entry `ix` is ignored.
    pub begin_mask: i64,
    /// Bit `ix` set: the end bound of spec entry `ix` is ignored.
    pub end_mask: i64,
    /// Bit `ix` set: entry `ix` selects a single element and the axis is removed from the output.
    pub shrink_axis_mask: i64,
}
11
/// Resolved slicing parameters for a single axis, as produced by
/// `StridedSlice::prepare_one_dim`.
#[derive(Debug, Clone, PartialEq)]
pub struct Dim {
    /// Position of the first element to return.
    pub begin: TDim,
    /// Position of the first element not to return.
    pub end: TDim,
    /// Step between two returned elements; negative when walking the axis backwards.
    pub stride: i32,
    /// Whether this axis is reduced to a single element and must be removed from the output.
    pub shrink: bool,
}
21
22impl Dim {
23    pub fn soft_len(&self) -> TractResult<TDim> {
24        if let Ok(len) = (self.end.clone() - &self.begin).to_isize() {
25            Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim())
26        } else if self.stride == 1 {
27            Ok(self.end.clone() - &self.begin)
28        } else {
29            bail!("Streaming dimensions with strides are not supported for now")
30        }
31    }
32}
33
34impl StridedSlice {
35    fn must_shrink(&self, ix: usize) -> bool {
36        self.shrink_axis_mask & (1 << ix) != 0
37    }
38    fn ignore_begin(&self, ix: usize) -> bool {
39        self.begin_mask & (1 << ix) != 0
40    }
41    fn ignore_end(&self, ix: usize) -> bool {
42        self.end_mask & (1 << ix) != 0
43    }
    /// Resolves entry `ix` of the slicing specification against an axis of size `dim`.
    ///
    /// `begin`, `end` and `strides` are indexed by `ix`; an entry missing from any of them
    /// falls back to a default (whole axis, stride 1). The returned [`Dim`] has its bounds
    /// made non-negative and clamped to the axis whenever they are concrete integers.
    /// A selection known to be empty is reported as `begin == end == 0`.
    ///
    /// # Errors
    ///
    /// Fails if the bound tensors cannot be cast or read as expected.
    pub fn prepare_one_dim(
        &self,
        ix: usize,
        dim: &TDim,
        begin: &Tensor,
        end: &Tensor,
        strides: &[i32],
    ) -> TractResult<Dim> {
        // cast bounds to Option<TDim>, dealing with ignore from mask, and spec shorter than dim
        // also for end, magic values in onnx :/
        let mut begin: Option<TDim> = if ix >= begin.len() {
            None
        } else {
            let begin = begin.cast_to::<TDim>()?;
            begin.try_as_plain()?.as_slice::<TDim>()?.get(ix).cloned()
        };

        let mut end: Option<TDim> = if self.ignore_end(ix) || ix >= end.len() {
            None
        } else if end.datum_type() == i64::datum_type() {
            // `ix < end.len()` was checked above, so the lookup cannot fail
            let end = *end.try_as_plain()?.as_slice::<i64>()?.get(ix).unwrap();
            // extreme values are treated as "no end bound"
            if end == i64::MAX || end == i64::MIN || end == i64::MIN + 1 || end == (i32::MAX as i64)
            {
                None
            } else {
                Some(end.to_dim())
            }
        } else {
            let end = end.cast_to::<TDim>()?;
            end.try_as_plain()?.as_slice::<TDim>()?.get(ix).cloned()
        };

        // a missing stride defaults to 1
        let stride = strides.get(ix).cloned().unwrap_or(1);

        // deal with negative indexing
        fn fix_negative(bound: &mut TDim, dim: &TDim) {
            let neg = if bound.prove_positive_or_zero() {
                false
            } else if bound.prove_negative_or_zero() {
                true
            } else {
                // sign not provable: with exactly one symbol, probe the sign by evaluating
                // the bound with that symbol set to a large value; otherwise assume it is
                // not negative
                #[allow(clippy::mutable_key_type)]
                let symbols = bound.symbols();
                if symbols.len() == 1 {
                    let sym = symbols.into_iter().next().unwrap();
                    let values = SymbolValues::default().with(&sym, 100_000_000);
                    bound.eval(&values).to_isize().unwrap() < 0
                } else {
                    false
                }
            };
            if neg {
                // negative bounds count from the end of the axis
                *bound = bound.clone() + dim;
            }
        }
        if let Some(begin) = begin.as_mut() {
            fix_negative(begin, dim)
        }
        if let Some(end) = end.as_mut() {
            fix_negative(end, dim)
        }

        // shrinking keeps the single element at `begin` (0 when no begin is given)
        if self.must_shrink(ix) {
            return Ok(Dim {
                begin: begin.clone().unwrap_or_else(|| 0.to_dim()),
                end: begin.unwrap_or_else(|| 0.to_dim()) + 1,
                stride: 1,
                shrink: true,
            });
        }

        // must happen after dealing with must_shrink :/
        if self.ignore_begin(ix) {
            begin = None;
        }

        // default begin: first element when walking forward, last element when walking backward
        let mut begin =
            begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 });
        // begin still negative: nothing to return when walking backward, clamp to 0 otherwise
        if begin.to_isize().map(|b| b < 0).unwrap_or(false) {
            if stride < 0 {
                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
            } else {
                begin = 0.to_dim();
            }
        }
        // begin past the last element: nothing to return when walking forward, clamp to the
        // last element otherwise
        if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize())
            && b > d - 1
        {
            if stride > 0 {
                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
            } else {
                begin = (d - 1).to_dim()
            }
        }

        // default end: one past the last element when walking forward, -1 when walking backward
        let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() });
        // end still negative: nothing to return when walking forward, clamp to -1 otherwise
        if end.to_isize().map(|e| e < 0).unwrap_or(false) {
            if stride > 0 {
                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
            } else {
                end = (-1).to_dim();
            }
        }
        // end past the last element: clamp to the axis size when walking forward, nothing to
        // return otherwise
        if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize())
            && e > d - 1
        {
            if stride > 0 {
                end = d.to_dim()
            } else {
                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
            }
        }
        Ok(Dim { begin, end, stride, shrink: false })
    }
158
159    fn wire(
160        &self,
161        prefix: &str,
162        target: &mut TypedModel,
163        inputs: &[OutletId],
164    ) -> TractResult<TVec<OutletId>> {
165        let params: TVec<Option<Arc<Tensor>>> = inputs[1..]
166            .iter()
167            .map(|i| Ok(target.outlet_fact(*i)?.konst.clone()))
168            .collect::<TractResult<_>>()?;
169        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
170        let strides: TVec<i32> = if let Some(i) = self.optional_steps_input {
171            let strides = params[i - 1]
172                .as_ref()
173                .context("StridedSlice is typable only if stride is a const")?
174                .cast_to::<i32>()?;
175            strides.try_as_plain()?.as_slice::<i32>()?.into()
176        } else {
177            tvec![1; input_shape.rank()]
178        };
179        let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
180            let axes = params[i - 1]
181                .as_ref()
182                .context("StridedSlice is typable only if axis is a const")?
183                .cast_to::<i32>()?;
184            axes.try_as_plain()?
185                .as_slice::<i32>()?
186                .iter()
187                .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize)
188                .collect()
189        } else {
190            (0..input_shape.rank()).collect()
191        };
192        let mut wire = inputs[0];
193        let begin = params[0].as_ref();
194        let end = params[1].as_ref();
195        for (ix, &axis) in axes.iter().enumerate() {
196            if let (Some(begin), Some(end)) = (begin, end) {
197                let d = &input_shape[axis];
198                let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
199                let (left, right) = if preped.stride > 0 {
200                    (preped.begin, preped.end)
201                } else {
202                    (preped.end + 1, preped.begin + 1)
203                };
204                wire = target.wire_node(
205                    format!("{prefix}.slice-axis-{axis}"),
206                    crate::ops::array::Slice::new(axis, left, right),
207                    [wire].as_ref(),
208                )?[0];
209                if preped.stride != 1 {
210                    wire = target.wire_node(
211                        format!("{prefix}.stride-axis-{axis}"),
212                        crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0),
213                        [wire].as_ref(),
214                    )?[0];
215                }
216            } else if strides[ix] == 1 {
217                let left = target.wire_node(
218                    format!("{prefix}.slice-axis-{axis}-start"),
219                    crate::ops::array::Slice::new(0, ix, ix + 1),
220                    &[inputs[1]],
221                )?;
222                let left = target.wire_node(
223                    format!("{prefix}.slice-axis-{axis}-start-rm-axis"),
224                    AxisOp::Rm(0),
225                    &left,
226                )?[0];
227                let right = target.wire_node(
228                    format!("{prefix}.slice-axis-{axis}-end"),
229                    crate::ops::array::Slice::new(0, ix, ix + 1),
230                    &[inputs[2]],
231                )?;
232                let right = target.wire_node(
233                    format!("{prefix}.slice-axis-{axis}-end-rm-axis"),
234                    AxisOp::Rm(0),
235                    &right,
236                )?[0];
237                let sym = target.symbols.new_with_prefix("l");
238                wire = target.wire_node(
239                    format!("{prefix}.slice-axis-{axis}"),
240                    crate::ops::array::DynSlice::new(axis, sym.to_dim()),
241                    &[wire, left, right],
242                )?[0];
243            }
244        }
245        let mut shrink = input_shape
246            .iter()
247            .enumerate()
248            .filter(|(ix, _d)| self.must_shrink(*ix))
249            .map(|pair| pair.0)
250            .collect::<Vec<_>>();
251        shrink.sort();
252        for axis in shrink.iter().rev() {
253            wire = target.wire_node(
254                format!("{prefix}.RmDim-{axis}"),
255                AxisOp::Rm(*axis),
256                [wire].as_ref(),
257            )?[0];
258        }
259        target.rename_node(wire.node, prefix)?;
260        Ok(tvec!(wire))
261    }
262}
263
264impl Op for StridedSlice {
265    fn name(&self) -> StaticName {
266        "StridedSlice".into()
267    }
268
269    op_as_typed_op!();
270}
271
272impl EvalOp for StridedSlice {
273    fn is_stateless(&self) -> bool {
274        true
275    }
276
277    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
278        let mut model = TypedModel::default();
279        let scope = inputs.iter().find_map(|i| {
280            i.try_as_plain().ok().and_then(|d| {
281                d.as_slice::<TDim>()
282                    .ok()
283                    .and_then(|slice| slice.iter().find_map(|dim| dim.find_scope()))
284            })
285        });
286        model.symbols = scope.unwrap_or_default();
287        let mut source = tvec!();
288        for (ix, input) in inputs.iter().enumerate() {
289            source.push(model.add_source(
290                format!("adhoc_input.{ix}"),
291                input.clone().into_arc_tensor().try_into()?,
292            )?);
293        }
294        let output = self.wire("adhoc", &mut model, &source)?;
295        model.select_output_outlets(&output)?;
296        model.into_runnable()?.run(inputs)
297    }
298}
299
300impl TypedOp for StridedSlice {
301    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
302        let mut model = TypedModel::default();
303        let mut source = tvec!();
304        for (ix, input) in inputs.iter().enumerate() {
305            source.push(model.add_source(format!("adhoc_input.{ix}"), (*input).clone())?);
306        }
307        let output = self.wire("adhoc", &mut model, &source)?;
308        model.select_output_outlets(&output)?;
309        Ok(tvec!(model.outlet_fact(output[0])?.clone()))
310    }
311
312    fn declutter(
313        &self,
314        model: &TypedModel,
315        node: &TypedNode,
316    ) -> TractResult<Option<TypedModelPatch>> {
317        let mut patch = TypedModelPatch::default();
318        let mut source = tvec!();
319        for &input in &node.inputs {
320            source.push(patch.tap_model(model, input)?);
321        }
322        let output = self.wire(&node.name, &mut patch, &source)?;
323        patch.shunt_outside(model, node.id.into(), output[0])?;
324        Ok(Some(patch))
325    }
326
327    as_op!();
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates the one-dimensional slice `input[start:end:stride]`.
    ///
    /// A `None` start or end sets the matching mask bit so the bound is ignored; a `None`
    /// stride leaves the steps input out altogether.
    fn apply(
        input: &[i32],
        start: Option<isize>,
        end: Option<isize>,
        stride: Option<isize>,
    ) -> TValue {
        let op = StridedSlice {
            optional_axes_input: None,
            optional_steps_input: stride.map(|_| 3),
            begin_mask: i64::from(start.is_none()),
            end_mask: i64::from(end.is_none()),
            shrink_axis_mask: 0,
        };
        let first = start.unwrap_or(0) as i32;
        let last = end.unwrap_or(0) as i32;
        let mut inputs = tvec!(
            tensor1(input).into(),
            tensor1(&[first]).into(),
            tensor1(&[last]).into(),
        );
        if let Some(step) = stride {
            inputs.push(tensor1(&[step as i32]).into());
        }
        let mut outputs = op.eval(inputs).unwrap();
        outputs.remove(0)
    }

    #[test]
    fn numpy_pos_stride() {
        // [0,1,2,3][::2] => [0, 2]
        let expected: TValue = tensor1(&[0, 2]).into();
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(2)), expected);
    }

    #[test]
    fn numpy_neg_stride() {
        // [0,1,2,3][::-2] => [3, 1]
        let expected: TValue = tensor1(&[3, 1]).into();
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(-2)), expected);
    }

    #[test]
    fn numpy_neg_stride_with_start_even() {
        // [0,1,2,3][-1::-2] => [3, 1]
        let expected: TValue = tensor1(&[3, 1]).into();
        assert_eq!(apply(&[0, 1, 2, 3], Some(-1), None, Some(-2)), expected);
    }

    #[test]
    fn numpy_neg_stride_with_start_odd() {
        // [0,1,2,3,4][-1::-2] => [4, 2, 0]
        let expected: TValue = tensor1(&[4, 2, 0]).into();
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), expected);
    }
}
382}