Skip to main content

tract_core/ops/array/
strided_slice.rs

1use crate::internal::*;
2
/// Strided slicing over several axes at once, in the spirit of TensorFlow `StridedSlice` and
/// ONNX `Slice`. It is evaluated and typed by decomposing it into simpler ops.
#[derive(Clone, Debug, Hash)]
pub struct StridedSlice {
    /// Position, among the op inputs, of the optional constant "axes" input.
    pub optional_axes_input: Option<usize>,
    /// Position, among the op inputs, of the optional constant "steps" (strides) input.
    pub optional_steps_input: Option<usize>,
    /// Bit `i` set: the begin bound of slicing entry `i` is ignored.
    pub begin_mask: i64,
    /// Bit `i` set: the end bound of slicing entry `i` is ignored.
    pub end_mask: i64,
    /// Bit `i` set: slicing entry `i` keeps a single element and its axis is removed.
    pub shrink_axis_mask: i64,
}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct Dim {
14    // position of the first element to return
15    pub begin: TDim,
16    // position of the first element not to return
17    pub end: TDim,
18    pub stride: i32,
19    pub shrink: bool,
20}
21
22impl Dim {
23    pub fn soft_len(&self) -> TractResult<TDim> {
24        if let Ok(len) = (self.end.clone() - &self.begin).to_isize() {
25            Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim())
26        } else if self.stride == 1 {
27            Ok(self.end.clone() - &self.begin)
28        } else {
29            bail!("Streaming dimensions with strides are not supported for now")
30        }
31    }
32}
33
34impl StridedSlice {
35    fn must_shrink(&self, ix: usize) -> bool {
36        self.shrink_axis_mask & (1 << ix) != 0
37    }
38    fn ignore_begin(&self, ix: usize) -> bool {
39        self.begin_mask & (1 << ix) != 0
40    }
41    fn ignore_end(&self, ix: usize) -> bool {
42        self.end_mask & (1 << ix) != 0
43    }
44    pub fn prepare_one_dim(
45        &self,
46        ix: usize,
47        dim: &TDim,
48        begin: &Tensor,
49        end: &Tensor,
50        strides: &[i32],
51    ) -> TractResult<Dim> {
52        // cast bouds to Option<Dim>, dealing with ignore from mask, and spec shorted than dim
53        // also for end, magic values in onnx :/
54        let mut begin: Option<TDim> = if ix >= begin.len() {
55            None
56        } else {
57            let begin = begin.cast_to::<TDim>()?;
58            begin.try_as_dense()?.as_slice::<TDim>()?.get(ix).cloned()
59        };
60
61        let mut end: Option<TDim> = if self.ignore_end(ix) || ix >= end.len() {
62            None
63        } else if end.datum_type() == i64::datum_type() {
64            let end = *end.try_as_dense()?.as_slice::<i64>()?.get(ix).unwrap();
65            if end == i64::MAX || end == i64::MIN || end == i64::MIN + 1 || end == (i32::MAX as i64)
66            {
67                None
68            } else {
69                Some(end.to_dim())
70            }
71        } else {
72            let end = end.cast_to::<TDim>()?;
73            end.try_as_dense()?.as_slice::<TDim>()?.get(ix).cloned()
74        };
75
76        let stride = strides.get(ix).cloned().unwrap_or(1);
77
78        // deal with negative indexing
79        fn fix_negative(bound: &mut TDim, dim: &TDim) {
80            let neg = if bound.prove_positive_or_zero() {
81                false
82            } else if bound.prove_negative_or_zero() {
83                true
84            } else {
85                #[allow(clippy::mutable_key_type)]
86                let symbols = bound.symbols();
87                if symbols.len() == 1 {
88                    let sym = symbols.into_iter().next().unwrap();
89                    let values = SymbolValues::default().with(&sym, 100_000_000);
90                    bound.eval(&values).to_isize().unwrap() < 0
91                } else {
92                    false
93                }
94            };
95            if neg {
96                *bound = bound.clone() + dim;
97            }
98        }
99        if let Some(begin) = begin.as_mut() {
100            fix_negative(begin, dim)
101        }
102        if let Some(end) = end.as_mut() {
103            fix_negative(end, dim)
104        }
105
106        if self.must_shrink(ix) {
107            return Ok(Dim {
108                begin: begin.clone().unwrap_or_else(|| 0.to_dim()),
109                end: begin.unwrap_or_else(|| 0.to_dim()) + 1,
110                stride: 1,
111                shrink: true,
112            });
113        }
114
115        // must happen after dealing with must_shrink :/
116        if self.ignore_begin(ix) {
117            begin = None;
118        }
119
120        let mut begin =
121            begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 });
122        if begin.to_isize().map(|b| b < 0).unwrap_or(false) {
123            if stride < 0 {
124                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
125            } else {
126                begin = 0.to_dim();
127            }
128        }
129        if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize()) {
130            if b > d - 1 {
131                if stride > 0 {
132                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
133                } else {
134                    begin = (d - 1).to_dim()
135                }
136            }
137        }
138
139        let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() });
140        if end.to_isize().map(|e| e < 0).unwrap_or(false) {
141            if stride > 0 {
142                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
143            } else {
144                end = (-1).to_dim();
145            }
146        }
147        if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize()) {
148            if e > d - 1 {
149                if stride > 0 {
150                    end = d.to_dim()
151                } else {
152                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
153                }
154            }
155        }
156        Ok(Dim { begin, end, stride, shrink: false })
157    }
158
159    fn wire(
160        &self,
161        prefix: &str,
162        target: &mut TypedModel,
163        inputs: &[OutletId],
164    ) -> TractResult<TVec<OutletId>> {
165        let params: TVec<Option<Arc<Tensor>>> = inputs[1..]
166            .iter()
167            .map(|i| Ok(target.outlet_fact(*i)?.konst.clone()))
168            .collect::<TractResult<_>>()?;
169        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
170        let strides: TVec<i32> = if let Some(i) = self.optional_steps_input {
171            let strides = params[i - 1]
172                .as_ref()
173                .context("StridedSlice is typable only if stride is a const")?
174                .cast_to::<i32>()?;
175            strides.try_as_dense()?.as_slice::<i32>()?.into()
176        } else {
177            tvec![1; input_shape.rank()]
178        };
179        let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
180            let axes = params[i - 1]
181                .as_ref()
182                .context("StridedSlice is typable only if axis is a const")?
183                .cast_to::<i32>()?;
184            axes.try_as_dense()?
185                .as_slice::<i32>()?
186                .iter()
187                .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize)
188                .collect()
189        } else {
190            (0..input_shape.rank()).collect()
191        };
192        let mut wire = inputs[0];
193        let begin = params[0].as_ref();
194        let end = params[1].as_ref();
195        for (ix, &axis) in axes.iter().enumerate() {
196            if let (Some(begin), Some(end)) = (begin, end) {
197                let d = &input_shape[axis];
198                let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
199                let (left, right) = if preped.stride > 0 {
200                    (preped.begin, preped.end)
201                } else {
202                    (preped.end + 1, preped.begin + 1)
203                };
204                wire = target.wire_node(
205                    format!("{prefix}.slice-axis-{axis}"),
206                    crate::ops::array::Slice::new(axis, left, right),
207                    [wire].as_ref(),
208                )?[0];
209                if preped.stride != 1 {
210                    wire = target.wire_node(
211                        format!("{prefix}.stride-axis-{axis}"),
212                        crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0),
213                        [wire].as_ref(),
214                    )?[0];
215                }
216            } else if strides[ix] == 1 {
217                let left = target.wire_node(
218                    format!("{prefix}.slice-axis-{axis}-start"),
219                    crate::ops::array::Slice::new(0, ix, ix + 1),
220                    &[inputs[1]],
221                )?;
222                let left = target.wire_node(
223                    format!("{prefix}.slice-axis-{axis}-start-rm-axis"),
224                    AxisOp::Rm(0),
225                    &left,
226                )?[0];
227                let right = target.wire_node(
228                    format!("{prefix}.slice-axis-{axis}-end"),
229                    crate::ops::array::Slice::new(0, ix, ix + 1),
230                    &[inputs[2]],
231                )?;
232                let right = target.wire_node(
233                    format!("{prefix}.slice-axis-{axis}-end-rm-axis"),
234                    AxisOp::Rm(0),
235                    &right,
236                )?[0];
237                let sym = target.symbols.new_with_prefix("l");
238                wire = target.wire_node(
239                    format!("{prefix}.slice-axis-{axis}"),
240                    crate::ops::array::DynSlice::new(axis, sym.to_dim()),
241                    &[wire, left, right],
242                )?[0];
243            }
244        }
245        let mut shrink = input_shape
246            .iter()
247            .enumerate()
248            .filter(|(ix, _d)| self.must_shrink(*ix))
249            .map(|pair| pair.0)
250            .collect::<Vec<_>>();
251        shrink.sort();
252        for axis in shrink.iter().rev() {
253            wire = target.wire_node(
254                format!("{prefix}.RmDim-{axis}"),
255                AxisOp::Rm(*axis),
256                [wire].as_ref(),
257            )?[0];
258        }
259        target.rename_node(wire.node, prefix)?;
260        Ok(tvec!(wire))
261    }
262}
263
264impl Op for StridedSlice {
265    fn name(&self) -> StaticName {
266        "StridedSlice".into()
267    }
268
269    op_as_typed_op!();
270}
271
272impl EvalOp for StridedSlice {
273    fn is_stateless(&self) -> bool {
274        true
275    }
276
277    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
278        let mut model = TypedModel::default();
279        let scope = inputs.iter().find_map(|i| {
280            i.try_as_dense().ok().and_then(|d| {
281                d.as_slice::<TDim>()
282                    .ok()
283                    .and_then(|slice| slice.iter().find_map(|dim| dim.find_scope()))
284            })
285        });
286        model.symbols = scope.unwrap_or_default();
287        let mut source = tvec!();
288        for (ix, input) in inputs.iter().enumerate() {
289            source.push(
290                model.add_source(
291                    format!("adhoc_input.{ix}"),
292                    input.clone().into_arc_tensor().into(),
293                )?,
294            );
295        }
296        let output = self.wire("adhoc", &mut model, &source)?;
297        model.set_output_outlets(&output)?;
298        model.into_runnable()?.run(inputs)
299    }
300}
301
302impl TypedOp for StridedSlice {
303    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
304        let mut model = TypedModel::default();
305        let mut source = tvec!();
306        for (ix, input) in inputs.iter().enumerate() {
307            source.push(model.add_source(format!("adhoc_input.{ix}"), (*input).clone())?);
308        }
309        let output = self.wire("adhoc", &mut model, &source)?;
310        model.set_output_outlets(&output)?;
311        Ok(tvec!(model.outlet_fact(output[0])?.clone()))
312    }
313
314    fn declutter(
315        &self,
316        model: &TypedModel,
317        node: &TypedNode,
318    ) -> TractResult<Option<TypedModelPatch>> {
319        let mut patch = TypedModelPatch::default();
320        let mut source = tvec!();
321        for &input in &node.inputs {
322            source.push(patch.tap_model(model, input)?);
323        }
324        let output = self.wire(&node.name, &mut patch, &source)?;
325        patch.shunt_outside(model, node.id.into(), output[0])?;
326        Ok(Some(patch))
327    }
328
329    as_op!();
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates a single-axis StridedSlice on `input`, mirroring numpy's
    /// `input[start:end:stride]`.
    ///
    /// A `None` bound is masked out through `begin_mask` / `end_mask`. A `None` stride leaves
    /// the steps input out entirely.
    fn apply(
        input: &[i32],
        start: Option<isize>,
        end: Option<isize>,
        stride: Option<isize>,
    ) -> TValue {
        let op = StridedSlice {
            optional_axes_input: None,
            optional_steps_input: if stride.is_some() { Some(3) } else { None },
            begin_mask: if start.is_some() { 0 } else { 1 },
            end_mask: if end.is_some() { 0 } else { 1 },
            shrink_axis_mask: 0,
        };
        let mut inputs = tvec!(
            tensor1(input).into(),
            tensor1(&[start.unwrap_or(0) as i32]).into(),
            tensor1(&[end.unwrap_or(0) as i32]).into(),
        );
        if let Some(stride) = stride {
            inputs.push(tensor1(&[stride as i32]).into());
        }
        op.eval(inputs).unwrap().remove(0)
    }

    #[test]
    fn numpy_pos_stride() {
        // [0,1,2,3][::2] => [0, 2]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(2)), tensor1(&[0, 2]).into());
    }

    #[test]
    fn numpy_neg_stride() {
        // [0,1,2,3][::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_even() {
        // [0,1,2,3][-1::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], Some(-1), None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_odd() {
        // [0,1,2,3,4][-1::-2] => [4, 2, 0]
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), tensor1(&[4, 2, 0]).into());
    }

    #[test]
    fn numpy_pos_stride_with_bounds() {
        // [0,1,2,3,4,5][1:5:2] => [1, 3]
        assert_eq!(apply(&[0, 1, 2, 3, 4, 5], Some(1), Some(5), Some(2)), tensor1(&[1, 3]).into());
    }

    #[test]
    fn numpy_neg_start_no_stride() {
        // [0,1,2,3][-3:] => [1, 2, 3]
        assert_eq!(apply(&[0, 1, 2, 3], Some(-3), None, None), tensor1(&[1, 2, 3]).into());
    }

    #[test]
    fn numpy_end_past_axis_is_clamped() {
        // [0,1,2,3][1:10] => [1, 2, 3]
        assert_eq!(apply(&[0, 1, 2, 3], Some(1), Some(10), None), tensor1(&[1, 2, 3]).into());
    }
}