// tract_core/ops/array/strided_slice.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, Hash)]
4pub struct StridedSlice {
5    pub optional_axes_input: Option<usize>,
6    pub optional_steps_input: Option<usize>,
7    pub begin_mask: i64,
8    pub end_mask: i64,
9    pub shrink_axis_mask: i64,
10}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct Dim {
14    // position of the first element to return
15    pub begin: TDim,
16    // position of the first element not to return
17    pub end: TDim,
18    pub stride: i32,
19    pub shrink: bool,
20}
21
22impl Dim {
23    pub fn soft_len(&self) -> TractResult<TDim> {
24        if let Ok(len) = (self.end.clone() - &self.begin).to_isize() {
25            Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim())
26        } else if self.stride == 1 {
27            Ok(self.end.clone() - &self.begin)
28        } else {
29            bail!("Streaming dimensions with strides are not supported for now")
30        }
31    }
32}
33
34impl StridedSlice {
35    fn must_shrink(&self, ix: usize) -> bool {
36        self.shrink_axis_mask & (1 << ix) != 0
37    }
38    fn ignore_begin(&self, ix: usize) -> bool {
39        self.begin_mask & (1 << ix) != 0
40    }
41    fn ignore_end(&self, ix: usize) -> bool {
42        self.end_mask & (1 << ix) != 0
43    }
44    pub fn prepare_one_dim(
45        &self,
46        ix: usize,
47        dim: &TDim,
48        begin: &Tensor,
49        end: &Tensor,
50        strides: &[i32],
51    ) -> TractResult<Dim> {
52        // cast bouds to Option<Dim>, dealing with ignore from mask, and spec shorted than dim
53        // also for end, magic values in onnx :/
54        let mut begin: Option<TDim> = if ix >= begin.len() {
55            None
56        } else {
57            let begin = begin.cast_to::<TDim>()?;
58            begin.as_slice::<TDim>()?.get(ix).cloned()
59        };
60
61        let mut end: Option<TDim> = if self.ignore_end(ix) || ix >= end.len() {
62            None
63        } else if end.datum_type() == i64::datum_type() {
64            let end = *end.as_slice::<i64>()?.get(ix).unwrap();
65            if end == i64::MAX || end == i64::MIN || end == i64::MIN + 1 || end == (i32::MAX as i64)
66            {
67                None
68            } else {
69                Some(end.to_dim())
70            }
71        } else {
72            let end = end.cast_to::<TDim>()?;
73            end.as_slice::<TDim>()?.get(ix).cloned()
74        };
75
76        let stride = strides.get(ix).cloned().unwrap_or(1);
77
78        // deal with negative indexing
79        fn fix_negative(bound: &mut TDim, dim: &TDim) {
80            let neg = if bound.prove_positive_or_zero() {
81                false
82            } else if bound.prove_negative_or_zero() {
83                true
84            } else {
85                #[allow(clippy::mutable_key_type)]
86                let symbols = bound.symbols();
87                if symbols.len() == 1 {
88                    let sym = symbols.into_iter().next().unwrap();
89                    let values = SymbolValues::default().with(&sym, 100_000_000);
90                    bound.eval(&values).to_isize().unwrap() < 0
91                } else {
92                    false
93                }
94            };
95            if neg {
96                *bound = bound.clone() + dim;
97            }
98        }
99        if let Some(begin) = begin.as_mut() {
100            fix_negative(begin, dim)
101        }
102        if let Some(end) = end.as_mut() {
103            fix_negative(end, dim)
104        }
105
106        if self.must_shrink(ix) {
107            return Ok(Dim {
108                begin: begin.clone().unwrap_or_else(|| 0.to_dim()),
109                end: begin.unwrap_or_else(|| 0.to_dim()) + 1,
110                stride: 1,
111                shrink: true,
112            });
113        }
114
115        // must happen after dealing with must_shrink :/
116        if self.ignore_begin(ix) {
117            begin = None;
118        }
119
120        let mut begin =
121            begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 });
122        if begin.to_isize().map(|b| b < 0).unwrap_or(false) {
123            if stride < 0 {
124                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
125            } else {
126                begin = 0.to_dim();
127            }
128        }
129        if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize()) {
130            if b > d - 1 {
131                if stride > 0 {
132                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
133                } else {
134                    begin = (d - 1).to_dim()
135                }
136            }
137        }
138
139        let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() });
140        if end.to_isize().map(|e| e < 0).unwrap_or(false) {
141            if stride > 0 {
142                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
143            } else {
144                end = (-1).to_dim();
145            }
146        }
147        if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize()) {
148            if e > d - 1 {
149                if stride > 0 {
150                    end = d.to_dim()
151                } else {
152                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
153                }
154            }
155        }
156        Ok(Dim { begin, end, stride, shrink: false })
157    }
158
159    fn wire(
160        &self,
161        prefix: &str,
162        target: &mut TypedModel,
163        inputs: &[OutletId],
164    ) -> TractResult<TVec<OutletId>> {
165        let params: TVec<Option<Arc<Tensor>>> = inputs[1..]
166            .iter()
167            .map(|i| Ok(target.outlet_fact(*i)?.konst.clone()))
168            .collect::<TractResult<_>>()?;
169        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
170        let strides: TVec<i32> = if let Some(i) = self.optional_steps_input {
171            let strides = params[i - 1]
172                .as_ref()
173                .context("StridedSlice is typable only if stride is a const")?
174                .cast_to::<i32>()?;
175            strides.as_slice::<i32>()?.into()
176        } else {
177            tvec![1; input_shape.rank()]
178        };
179        let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
180            let axes = params[i - 1]
181                .as_ref()
182                .context("StridedSlice is typable only if axis is a const")?
183                .cast_to::<i32>()?;
184            axes.as_slice::<i32>()?
185                .iter()
186                .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize)
187                .collect()
188        } else {
189            (0..input_shape.rank()).collect()
190        };
191        let mut wire = inputs[0];
192        let begin = params[0].as_ref();
193        let end = params[1].as_ref();
194        for (ix, &axis) in axes.iter().enumerate() {
195            if let (Some(begin), Some(end)) = (begin, end) {
196                let d = &input_shape[axis];
197                let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
198                let (left, right) = if preped.stride > 0 {
199                    (preped.begin, preped.end)
200                } else {
201                    (preped.end + 1, preped.begin + 1)
202                };
203                wire = target.wire_node(
204                    format!("{prefix}.slice-axis-{axis}"),
205                    crate::ops::array::Slice::new(axis, left, right),
206                    [wire].as_ref(),
207                )?[0];
208                if preped.stride != 1 {
209                    wire = target.wire_node(
210                        format!("{prefix}.stride-axis-{axis}"),
211                        crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0),
212                        [wire].as_ref(),
213                    )?[0];
214                }
215            } else if strides[ix] == 1 {
216                let left = target.wire_node(
217                    format!("{prefix}.slice-axis-{axis}-start"),
218                    crate::ops::array::Slice::new(0, ix, ix + 1),
219                    &[inputs[1]],
220                )?;
221                let left = target.wire_node(
222                    format!("{prefix}.slice-axis-{axis}-start-rm-axis"),
223                    AxisOp::Rm(0),
224                    &left,
225                )?[0];
226                let right = target.wire_node(
227                    format!("{prefix}.slice-axis-{axis}-end"),
228                    crate::ops::array::Slice::new(0, ix, ix + 1),
229                    &[inputs[2]],
230                )?;
231                let right = target.wire_node(
232                    format!("{prefix}.slice-axis-{axis}-end-rm-axis"),
233                    AxisOp::Rm(0),
234                    &right,
235                )?[0];
236                let sym = target.symbols.new_with_prefix("l");
237                wire = target.wire_node(
238                    format!("{prefix}.slice-axis-{axis}"),
239                    crate::ops::array::DynSlice::new(axis, sym.to_dim()),
240                    &[wire, left, right],
241                )?[0];
242            }
243        }
244        let mut shrink = input_shape
245            .iter()
246            .enumerate()
247            .filter(|(ix, _d)| self.must_shrink(*ix))
248            .map(|pair| pair.0)
249            .collect::<Vec<_>>();
250        shrink.sort();
251        for axis in shrink.iter().rev() {
252            wire = target.wire_node(
253                format!("{prefix}.RmDim-{axis}"),
254                AxisOp::Rm(*axis),
255                [wire].as_ref(),
256            )?[0];
257        }
258        target.rename_node(wire.node, prefix)?;
259        Ok(tvec!(wire))
260    }
261}
262
263impl Op for StridedSlice {
264    fn name(&self) -> StaticName {
265        "StridedSlice".into()
266    }
267
268    op_as_typed_op!();
269}
270
271impl EvalOp for StridedSlice {
272    fn is_stateless(&self) -> bool {
273        true
274    }
275
276    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
277        let mut model = TypedModel::default();
278        let scope = inputs.iter().find_map(|i| {
279            i.as_slice::<TDim>()
280                .ok()
281                .and_then(|slice| slice.iter().find_map(|dim| dim.find_scope()))
282        });
283        model.symbols = scope.unwrap_or_default();
284        let mut source = tvec!();
285        for (ix, input) in inputs.iter().enumerate() {
286            source.push(
287                model.add_source(
288                    format!("adhoc_input.{ix}"),
289                    input.clone().into_arc_tensor().into(),
290                )?,
291            );
292        }
293        let output = self.wire("adhoc", &mut model, &source)?;
294        model.set_output_outlets(&output)?;
295        model.into_runnable()?.run(inputs)
296    }
297}
298
299impl TypedOp for StridedSlice {
300    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
301        let mut model = TypedModel::default();
302        let mut source = tvec!();
303        for (ix, input) in inputs.iter().enumerate() {
304            source.push(model.add_source(format!("adhoc_input.{ix}"), (*input).clone())?);
305        }
306        let output = self.wire("adhoc", &mut model, &source)?;
307        model.set_output_outlets(&output)?;
308        Ok(tvec!(model.outlet_fact(output[0])?.clone()))
309    }
310
311    fn declutter(
312        &self,
313        model: &TypedModel,
314        node: &TypedNode,
315    ) -> TractResult<Option<TypedModelPatch>> {
316        let mut patch = TypedModelPatch::default();
317        let mut source = tvec!();
318        for &input in &node.inputs {
319            source.push(patch.tap_model(model, input)?);
320        }
321        let output = self.wire(&node.name, &mut patch, &source)?;
322        patch.shunt_outside(model, node.id.into(), output[0])?;
323        Ok(Some(patch))
324    }
325
326    as_op!();
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a single-axis `StridedSlice` over `input`, numpy style.
    ///
    /// `None` for `start`/`end` sets the matching mask bit. `None` for `stride` omits the
    /// steps input. Example: `apply(&[0, 1, 2, 3, 4, 5], None, None, Some(2))` is `[0, 2, 4]`.
    fn apply(
        input: &[i32],
        start: Option<isize>,
        end: Option<isize>,
        stride: Option<isize>,
    ) -> TValue {
        let op = StridedSlice {
            optional_axes_input: None,
            optional_steps_input: stride.map(|_| 3),
            begin_mask: if start.is_some() { 0 } else { 1 },
            end_mask: if end.is_some() { 0 } else { 1 },
            shrink_axis_mask: 0,
        };
        let mut inputs = tvec!(
            tensor1(input).into(),
            tensor1(&[start.unwrap_or(0) as i32]).into(),
            tensor1(&[end.unwrap_or(0) as i32]).into(),
        );
        if let Some(stride) = stride {
            inputs.push(tensor1(&[stride as i32]).into());
        }
        op.eval(inputs).unwrap().remove(0)
    }

    #[test]
    fn numpy_pos_stride() {
        // [0,1,2,3][::2] => [0, 2]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(2)), tensor1(&[0, 2]).into());
    }

    #[test]
    fn numpy_neg_stride() {
        // [0,1,2,3][::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_even() {
        // [0,1,2,3][-1::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], Some(-1), None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_odd() {
        // [0,1,2,3,4][-1::-2] => [4, 2, 0]
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), tensor1(&[4, 2, 0]).into());
    }

    #[test]
    fn numpy_start_end() {
        // [0,1,2,3,4][1:4] => [1, 2, 3]
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(1), Some(4), None), tensor1(&[1, 2, 3]).into());
    }

    #[test]
    fn numpy_neg_start() {
        // [0,1,2,3,4][-2:] => [3, 4]
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-2), None, None), tensor1(&[3, 4]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_end() {
        // [0,1,2,3,4][3:0:-1] => [3, 2, 1]
        assert_eq!(
            apply(&[0, 1, 2, 3, 4], Some(3), Some(0), Some(-1)),
            tensor1(&[3, 2, 1]).into()
        );
    }
}
377    fn numpy_neg_stride_with_start_odd() {
378        // [0,1,2,3][-1::-2] => [3, 1]
379        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), tensor1(&[4, 2, 0]).into());
380    }
381}