tract_core/ops/array/strided_slice.rs

1use crate::internal::*;
2
/// Strided slice operator.
///
/// The sliced tensor is input 0; the begin and end bounds are inputs 1 and 2. Each bound
/// tensor holds one entry per "axis spec". Axis spec `ix` applies to axis `ix`, or to the
/// `ix`-th entry of the axes input when `optional_axes_input` is set.
#[derive(Debug, Clone, Hash)]
pub struct StridedSlice {
    /// Index, among the op inputs, of a constant tensor listing the sliced axes. Negative
    /// values count from the last axis. When `None`, every axis is sliced, in order.
    pub optional_axes_input: Option<usize>,
    /// Index, among the op inputs, of a constant tensor holding one stride per axis spec.
    /// When `None`, all strides are 1.
    pub optional_steps_input: Option<usize>,
    /// Bit `ix` set: the begin bound of axis spec `ix` is ignored (start from the edge).
    pub begin_mask: i64,
    /// Bit `ix` set: the end bound of axis spec `ix` is ignored (run to the edge).
    pub end_mask: i64,
    /// Bit `ix` set: only the element at the begin bound is kept and the axis is removed.
    pub shrink_axis_mask: i64,
}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct Dim {
14    // position of the first element to return
15    pub begin: TDim,
16    // position of the first element not to return
17    pub end: TDim,
18    pub stride: i32,
19    pub shrink: bool,
20}
21
22impl Dim {
23    pub fn soft_len(&self) -> TractResult<TDim> {
24        if let Ok(len) = (self.end.clone() - &self.begin).to_isize() {
25            Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim())
26        } else if self.stride == 1 {
27            Ok(self.end.clone() - &self.begin)
28        } else {
29            bail!("Streaming dimensions with strides are not supported for now")
30        }
31    }
32}
33
34impl StridedSlice {
35    fn must_shrink(&self, ix: usize) -> bool {
36        self.shrink_axis_mask & (1 << ix) != 0
37    }
38    fn ignore_begin(&self, ix: usize) -> bool {
39        self.begin_mask & (1 << ix) != 0
40    }
41    fn ignore_end(&self, ix: usize) -> bool {
42        self.end_mask & (1 << ix) != 0
43    }
44    pub fn prepare_one_dim(
45        &self,
46        ix: usize,
47        dim: &TDim,
48        begin: &Tensor,
49        end: &Tensor,
50        strides: &[i32],
51    ) -> TractResult<Dim> {
52        // cast bouds to Option<Dim>, dealing with ignore from mask, and spec shorted than dim
53        // also for end, magic values in onnx :/
54        let mut begin: Option<TDim> = if ix >= begin.len() {
55            None
56        } else {
57            let begin = begin.cast_to::<TDim>()?;
58            begin.as_slice::<TDim>()?.get(ix).cloned()
59        };
60
61        let mut end: Option<TDim> = if self.ignore_end(ix) || ix >= end.len() {
62            None
63        } else if end.datum_type() == i64::datum_type() {
64            let end = *end.as_slice::<i64>()?.get(ix).unwrap();
65            if end == i64::MAX || end == i64::MIN || end == i64::MIN + 1 || end == (i32::MAX as i64)
66            {
67                None
68            } else {
69                Some(end.to_dim())
70            }
71        } else {
72            let end = end.cast_to::<TDim>()?;
73            end.as_slice::<TDim>()?.get(ix).cloned()
74        };
75
76        let stride = strides.get(ix).cloned().unwrap_or(1);
77
78        // deal with negative indexing
79        fn fix_negative(bound: &mut TDim, dim: &TDim) {
80            let neg = if let Ok(b) = bound.to_isize() {
81                b < 0
82            } else {
83                #[allow(clippy::mutable_key_type)]
84                let symbols = bound.symbols();
85                if symbols.len() == 1 {
86                    let sym = symbols.into_iter().next().unwrap();
87                    let values = SymbolValues::default().with(&sym, 100_000_000);
88                    bound.eval(&values).to_isize().unwrap() < 0
89                } else {
90                    false
91                }
92            };
93            if neg {
94                *bound = bound.clone() + dim;
95            }
96        }
97        if let Some(begin) = begin.as_mut() {
98            fix_negative(begin, dim)
99        }
100        if let Some(end) = end.as_mut() {
101            fix_negative(end, dim)
102        }
103
104        if self.must_shrink(ix) {
105            return Ok(Dim {
106                begin: begin.clone().unwrap_or_else(|| 0.to_dim()),
107                end: begin.unwrap_or_else(|| 0.to_dim()) + 1,
108                stride: 1,
109                shrink: true,
110            });
111        }
112
113        // must happen after dealing with must_shrink :/
114        if self.ignore_begin(ix) {
115            begin = None;
116        }
117
118        let mut begin =
119            begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 });
120        if begin.to_isize().map(|b| b < 0).unwrap_or(false) {
121            if stride < 0 {
122                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
123            } else {
124                begin = 0.to_dim();
125            }
126        }
127        if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize()) {
128            if b > d - 1 {
129                if stride > 0 {
130                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
131                } else {
132                    begin = (d - 1).to_dim()
133                }
134            }
135        }
136
137        let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() });
138        if end.to_isize().map(|e| e < 0).unwrap_or(false) {
139            if stride > 0 {
140                return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
141            } else {
142                end = (-1).to_dim();
143            }
144        }
145        if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize()) {
146            if e > d - 1 {
147                if stride > 0 {
148                    end = d.to_dim()
149                } else {
150                    return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
151                }
152            }
153        }
154        Ok(Dim { begin, end, stride, shrink: false })
155    }
156
157    fn wire(
158        &self,
159        prefix: &str,
160        target: &mut TypedModel,
161        inputs: &[OutletId],
162    ) -> TractResult<TVec<OutletId>> {
163        let params: TVec<Option<Arc<Tensor>>> = inputs[1..]
164            .iter()
165            .map(|i| Ok(target.outlet_fact(*i)?.konst.clone()))
166            .collect::<TractResult<_>>()?;
167        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
168        let strides: TVec<i32> = if let Some(i) = self.optional_steps_input {
169            let strides = params[i - 1]
170                .as_ref()
171                .context("StridedSlice is typable only if stride is a const")?
172                .cast_to::<i32>()?;
173            strides.as_slice::<i32>()?.into()
174        } else {
175            tvec![1; input_shape.rank()]
176        };
177        let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
178            let axes = params[i - 1]
179                .as_ref()
180                .context("StridedSlice is typable only if axis is a const")?
181                .cast_to::<i32>()?;
182            axes.as_slice::<i32>()?
183                .iter()
184                .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize)
185                .collect()
186        } else {
187            (0..input_shape.rank()).collect()
188        };
189        let mut wire = inputs[0];
190        let begin = params[0].as_ref();
191        let end = params[1].as_ref();
192        for (ix, &axis) in axes.iter().enumerate() {
193            if let (Some(begin), Some(end)) = (begin, end) {
194                let d = &input_shape[axis];
195                let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
196                let (left, right) = if preped.stride > 0 {
197                    (preped.begin, preped.end)
198                } else {
199                    (preped.end + 1, preped.begin + 1)
200                };
201                wire = target.wire_node(
202                    format!("{prefix}.slice-axis-{axis}"),
203                    crate::ops::array::Slice::new(axis, left, right),
204                    [wire].as_ref(),
205                )?[0];
206                if preped.stride != 1 {
207                    wire = target.wire_node(
208                        format!("{prefix}.stride-axis-{axis}"),
209                        crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0),
210                        [wire].as_ref(),
211                    )?[0];
212                }
213            } else if strides[ix] == 1 {
214                let left = target.wire_node(
215                    format!("{prefix}.slice-axis-{axis}-start"),
216                    crate::ops::array::Slice::new(0, ix, ix + 1),
217                    &[inputs[1]],
218                )?;
219                let left = target.wire_node(
220                    format!("{prefix}.slice-axis-{axis}-start-rm-axis"),
221                    AxisOp::Rm(0),
222                    &left,
223                )?[0];
224                let right = target.wire_node(
225                    format!("{prefix}.slice-axis-{axis}-end"),
226                    crate::ops::array::Slice::new(0, ix, ix + 1),
227                    &[inputs[2]],
228                )?;
229                let right = target.wire_node(
230                    format!("{prefix}.slice-axis-{axis}-end-rm-axis"),
231                    AxisOp::Rm(0),
232                    &right,
233                )?[0];
234                let sym = target.symbols.new_with_prefix("l");
235                wire = target.wire_node(
236                    format!("{prefix}.slice-axis-{axis}"),
237                    crate::ops::array::DynSlice::new(axis, sym.to_dim()),
238                    &[wire, left, right],
239                )?[0];
240            }
241        }
242        let mut shrink = input_shape
243            .iter()
244            .enumerate()
245            .filter(|(ix, _d)| self.must_shrink(*ix))
246            .map(|pair| pair.0)
247            .collect::<Vec<_>>();
248        shrink.sort();
249        for axis in shrink.iter().rev() {
250            wire = target.wire_node(
251                format!("{prefix}.RmDim-{axis}"),
252                AxisOp::Rm(*axis),
253                [wire].as_ref(),
254            )?[0];
255        }
256        target.rename_node(wire.node, prefix)?;
257        Ok(tvec!(wire))
258    }
259}
260
261impl Op for StridedSlice {
262    fn name(&self) -> Cow<str> {
263        "StridedSlice".into()
264    }
265
266    op_as_typed_op!();
267}
268
269impl EvalOp for StridedSlice {
270    fn is_stateless(&self) -> bool {
271        true
272    }
273
274    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
275        let mut model = TypedModel::default();
276        let mut source = tvec!();
277        for (ix, input) in inputs.iter().enumerate() {
278            source.push(model.add_source(
279                format!("adhoc_input.{}", ix),
280                input.clone().into_arc_tensor().into(),
281            )?);
282        }
283        let output = self.wire("adhoc", &mut model, &source)?;
284        model.set_output_outlets(&output)?;
285        model.into_runnable()?.run(inputs)
286    }
287}
288
289impl TypedOp for StridedSlice {
290    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
291        let mut model = TypedModel::default();
292        let mut source = tvec!();
293        for (ix, input) in inputs.iter().enumerate() {
294            source.push(model.add_source(format!("adhoc_input.{}", ix), (*input).clone())?);
295        }
296        let output = self.wire("adhoc", &mut model, &source)?;
297        model.set_output_outlets(&output)?;
298        Ok(tvec!(model.outlet_fact(output[0])?.clone()))
299    }
300
301    fn declutter(
302        &self,
303        model: &TypedModel,
304        node: &TypedNode,
305    ) -> TractResult<Option<TypedModelPatch>> {
306        let mut patch = TypedModelPatch::default();
307        let mut source = tvec!();
308        for &input in &node.inputs {
309            source.push(patch.tap_model(model, input)?);
310        }
311        let output = self.wire(&node.name, &mut patch, &source)?;
312        patch.shunt_outside(model, node.id.into(), output[0])?;
313        Ok(Some(patch))
314    }
315
316    as_op!();
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates a numpy-style 1-D slice `input[start:end:stride]`.
    ///
    /// A `None` for `start`/`end` sets the corresponding mask bit. A `None` for `stride`
    /// omits the steps input.
    fn apply(
        input: &[i32],
        start: Option<isize>,
        end: Option<isize>,
        stride: Option<isize>,
    ) -> TValue {
        // e.g. [0,1,2,3,4,5][::2] => [0, 2, 4]
        let op = StridedSlice {
            optional_axes_input: None,
            optional_steps_input: if stride.is_some() { Some(3) } else { None },
            begin_mask: if start.is_some() { 0 } else { 1 },
            end_mask: if end.is_some() { 0 } else { 1 },
            shrink_axis_mask: 0,
        };
        let mut inputs = tvec!(
            tensor1(input).into(),
            tensor1(&[start.unwrap_or(0) as i32]).into(),
            tensor1(&[end.unwrap_or(0) as i32]).into(),
        );
        if let Some(stride) = stride {
            inputs.push(tensor1(&[stride as i32]).into());
        }
        op.eval(inputs).unwrap().remove(0)
    }

    #[test]
    fn numpy_pos_stride() {
        // [0,1,2,3][::2] => [0, 2]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(2)), tensor1(&[0, 2]).into());
    }

    #[test]
    fn numpy_pos_stride_with_bounds() {
        // [0,1,2,3,4,5][1:5:2] => [1, 3]
        assert_eq!(
            apply(&[0, 1, 2, 3, 4, 5], Some(1), Some(5), Some(2)),
            tensor1(&[1, 3]).into()
        );
    }

    #[test]
    fn numpy_neg_start() {
        // [0,1,2,3][-3:] => [1, 2, 3]
        assert_eq!(apply(&[0, 1, 2, 3], Some(-3), None, None), tensor1(&[1, 2, 3]).into());
    }

    #[test]
    fn numpy_neg_end() {
        // [0,1,2,3][:-1] => [0, 1, 2]
        assert_eq!(apply(&[0, 1, 2, 3], None, Some(-1), None), tensor1(&[0, 1, 2]).into());
    }

    #[test]
    fn numpy_neg_stride() {
        // [0,1,2,3][::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], None, None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_even() {
        // [0,1,2,3][-1::-2] => [3, 1]
        assert_eq!(apply(&[0, 1, 2, 3], Some(-1), None, Some(-2)), tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_odd() {
        // [0,1,2,3,4][-1::-2] => [4, 2, 0]
        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), tensor1(&[4, 2, 0]).into());
    }
}
369        assert_eq!(apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2)), tensor1(&[4, 2, 0]).into());
370    }
371}