1use crate::internal::*;
2
/// Strided slice over one or several axes, expanded into `Slice`, `Downsample` and
/// `AxisOp::Rm` nodes by `wire` (used for typing, evaluation and decluttering).
///
/// Operator inputs are `[data, begin, end, ...]`. The optional axes and strides tensors are
/// found at the input indices stored in `optional_axes_input` / `optional_steps_input`.
///
/// Every mask is a bit set where bit `i` refers to the `i`-th begin/end/stride entry (with no
/// axes input, that is axis `i`).
#[derive(Clone, Debug, Hash)]
pub struct StridedSlice {
    /// Input index of the optional axes tensor; `None` slices every axis in order.
    pub optional_axes_input: Option<usize>,
    /// Input index of the optional strides tensor; `None` means every stride is 1.
    pub optional_steps_input: Option<usize>,
    /// Bit `i` set: ignore `begin[i]` and start from the default for the stride direction.
    pub begin_mask: i64,
    /// Bit `i` set: ignore `end[i]` and stop at the default for the stride direction.
    pub end_mask: i64,
    /// Bit `i` set: keep only the element at `begin[i]` and drop that axis from the output.
    pub shrink_axis_mask: i64,
}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct Dim {
14 pub begin: TDim,
16 pub end: TDim,
18 pub stride: i32,
19 pub shrink: bool,
20}
21
22impl Dim {
23 pub fn soft_len(&self) -> TractResult<TDim> {
24 if let Ok(len) = (self.end.clone() - &self.begin).to_isize() {
25 Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim())
26 } else if self.stride == 1 {
27 Ok(self.end.clone() - &self.begin)
28 } else {
29 bail!("Streaming dimensions with strides are not supported for now")
30 }
31 }
32}
33
34impl StridedSlice {
35 fn must_shrink(&self, ix: usize) -> bool {
36 self.shrink_axis_mask & (1 << ix) != 0
37 }
38 fn ignore_begin(&self, ix: usize) -> bool {
39 self.begin_mask & (1 << ix) != 0
40 }
41 fn ignore_end(&self, ix: usize) -> bool {
42 self.end_mask & (1 << ix) != 0
43 }
44 pub fn prepare_one_dim(
45 &self,
46 ix: usize,
47 dim: &TDim,
48 begin: &Tensor,
49 end: &Tensor,
50 strides: &[i32],
51 ) -> TractResult<Dim> {
52 let mut begin: Option<TDim> = if ix >= begin.len() {
55 None
56 } else {
57 let begin = begin.cast_to::<TDim>()?;
58 begin.try_as_dense()?.as_slice::<TDim>()?.get(ix).cloned()
59 };
60
61 let mut end: Option<TDim> = if self.ignore_end(ix) || ix >= end.len() {
62 None
63 } else if end.datum_type() == i64::datum_type() {
64 let end = *end.try_as_dense()?.as_slice::<i64>()?.get(ix).unwrap();
65 if end == i64::MAX || end == i64::MIN || end == i64::MIN + 1 || end == (i32::MAX as i64)
66 {
67 None
68 } else {
69 Some(end.to_dim())
70 }
71 } else {
72 let end = end.cast_to::<TDim>()?;
73 end.try_as_dense()?.as_slice::<TDim>()?.get(ix).cloned()
74 };
75
76 let stride = strides.get(ix).cloned().unwrap_or(1);
77
78 fn fix_negative(bound: &mut TDim, dim: &TDim) {
80 let neg = if bound.prove_positive_or_zero() {
81 false
82 } else if bound.prove_negative_or_zero() {
83 true
84 } else {
85 #[allow(clippy::mutable_key_type)]
86 let symbols = bound.symbols();
87 if symbols.len() == 1 {
88 let sym = symbols.into_iter().next().unwrap();
89 let values = SymbolValues::default().with(&sym, 100_000_000);
90 bound.eval(&values).to_isize().unwrap() < 0
91 } else {
92 false
93 }
94 };
95 if neg {
96 *bound = bound.clone() + dim;
97 }
98 }
99 if let Some(begin) = begin.as_mut() {
100 fix_negative(begin, dim)
101 }
102 if let Some(end) = end.as_mut() {
103 fix_negative(end, dim)
104 }
105
106 if self.must_shrink(ix) {
107 return Ok(Dim {
108 begin: begin.clone().unwrap_or_else(|| 0.to_dim()),
109 end: begin.unwrap_or_else(|| 0.to_dim()) + 1,
110 stride: 1,
111 shrink: true,
112 });
113 }
114
115 if self.ignore_begin(ix) {
117 begin = None;
118 }
119
120 let mut begin =
121 begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 });
122 if begin.to_isize().map(|b| b < 0).unwrap_or(false) {
123 if stride < 0 {
124 return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
125 } else {
126 begin = 0.to_dim();
127 }
128 }
129 if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize()) {
130 if b > d - 1 {
131 if stride > 0 {
132 return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
133 } else {
134 begin = (d - 1).to_dim()
135 }
136 }
137 }
138
139 let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() });
140 if end.to_isize().map(|e| e < 0).unwrap_or(false) {
141 if stride > 0 {
142 return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
143 } else {
144 end = (-1).to_dim();
145 }
146 }
147 if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize()) {
148 if e > d - 1 {
149 if stride > 0 {
150 end = d.to_dim()
151 } else {
152 return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false });
153 }
154 }
155 }
156 Ok(Dim { begin, end, stride, shrink: false })
157 }
158
159 fn wire(
160 &self,
161 prefix: &str,
162 target: &mut TypedModel,
163 inputs: &[OutletId],
164 ) -> TractResult<TVec<OutletId>> {
165 let params: TVec<Option<Arc<Tensor>>> = inputs[1..]
166 .iter()
167 .map(|i| Ok(target.outlet_fact(*i)?.konst.clone()))
168 .collect::<TractResult<_>>()?;
169 let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
170 let strides: TVec<i32> = if let Some(i) = self.optional_steps_input {
171 let strides = params[i - 1]
172 .as_ref()
173 .context("StridedSlice is typable only if stride is a const")?
174 .cast_to::<i32>()?;
175 strides.try_as_dense()?.as_slice::<i32>()?.into()
176 } else {
177 tvec![1; input_shape.rank()]
178 };
179 let axes: TVec<usize> = if let Some(i) = self.optional_axes_input {
180 let axes = params[i - 1]
181 .as_ref()
182 .context("StridedSlice is typable only if axis is a const")?
183 .cast_to::<i32>()?;
184 axes.try_as_dense()?
185 .as_slice::<i32>()?
186 .iter()
187 .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize)
188 .collect()
189 } else {
190 (0..input_shape.rank()).collect()
191 };
192 let mut wire = inputs[0];
193 let begin = params[0].as_ref();
194 let end = params[1].as_ref();
195 for (ix, &axis) in axes.iter().enumerate() {
196 if let (Some(begin), Some(end)) = (begin, end) {
197 let d = &input_shape[axis];
198 let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
199 let (left, right) = if preped.stride > 0 {
200 (preped.begin, preped.end)
201 } else {
202 (preped.end + 1, preped.begin + 1)
203 };
204 wire = target.wire_node(
205 format!("{prefix}.slice-axis-{axis}"),
206 crate::ops::array::Slice::new(axis, left, right),
207 [wire].as_ref(),
208 )?[0];
209 if preped.stride != 1 {
210 wire = target.wire_node(
211 format!("{prefix}.stride-axis-{axis}"),
212 crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0),
213 [wire].as_ref(),
214 )?[0];
215 }
216 } else if strides[ix] == 1 {
217 let left = target.wire_node(
218 format!("{prefix}.slice-axis-{axis}-start"),
219 crate::ops::array::Slice::new(0, ix, ix + 1),
220 &[inputs[1]],
221 )?;
222 let left = target.wire_node(
223 format!("{prefix}.slice-axis-{axis}-start-rm-axis"),
224 AxisOp::Rm(0),
225 &left,
226 )?[0];
227 let right = target.wire_node(
228 format!("{prefix}.slice-axis-{axis}-end"),
229 crate::ops::array::Slice::new(0, ix, ix + 1),
230 &[inputs[2]],
231 )?;
232 let right = target.wire_node(
233 format!("{prefix}.slice-axis-{axis}-end-rm-axis"),
234 AxisOp::Rm(0),
235 &right,
236 )?[0];
237 let sym = target.symbols.new_with_prefix("l");
238 wire = target.wire_node(
239 format!("{prefix}.slice-axis-{axis}"),
240 crate::ops::array::DynSlice::new(axis, sym.to_dim()),
241 &[wire, left, right],
242 )?[0];
243 }
244 }
245 let mut shrink = input_shape
246 .iter()
247 .enumerate()
248 .filter(|(ix, _d)| self.must_shrink(*ix))
249 .map(|pair| pair.0)
250 .collect::<Vec<_>>();
251 shrink.sort();
252 for axis in shrink.iter().rev() {
253 wire = target.wire_node(
254 format!("{prefix}.RmDim-{axis}"),
255 AxisOp::Rm(*axis),
256 [wire].as_ref(),
257 )?[0];
258 }
259 target.rename_node(wire.node, prefix)?;
260 Ok(tvec!(wire))
261 }
262}
263
264impl Op for StridedSlice {
265 fn name(&self) -> StaticName {
266 "StridedSlice".into()
267 }
268
269 op_as_typed_op!();
270}
271
272impl EvalOp for StridedSlice {
273 fn is_stateless(&self) -> bool {
274 true
275 }
276
277 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
278 let mut model = TypedModel::default();
279 let scope = inputs.iter().find_map(|i| {
280 i.try_as_dense().ok().and_then(|d| {
281 d.as_slice::<TDim>()
282 .ok()
283 .and_then(|slice| slice.iter().find_map(|dim| dim.find_scope()))
284 })
285 });
286 model.symbols = scope.unwrap_or_default();
287 let mut source = tvec!();
288 for (ix, input) in inputs.iter().enumerate() {
289 source.push(
290 model.add_source(
291 format!("adhoc_input.{ix}"),
292 input.clone().into_arc_tensor().into(),
293 )?,
294 );
295 }
296 let output = self.wire("adhoc", &mut model, &source)?;
297 model.set_output_outlets(&output)?;
298 model.into_runnable()?.run(inputs)
299 }
300}
301
302impl TypedOp for StridedSlice {
303 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
304 let mut model = TypedModel::default();
305 let mut source = tvec!();
306 for (ix, input) in inputs.iter().enumerate() {
307 source.push(model.add_source(format!("adhoc_input.{ix}"), (*input).clone())?);
308 }
309 let output = self.wire("adhoc", &mut model, &source)?;
310 model.set_output_outlets(&output)?;
311 Ok(tvec!(model.outlet_fact(output[0])?.clone()))
312 }
313
314 fn declutter(
315 &self,
316 model: &TypedModel,
317 node: &TypedNode,
318 ) -> TractResult<Option<TypedModelPatch>> {
319 let mut patch = TypedModelPatch::default();
320 let mut source = tvec!();
321 for &input in &node.inputs {
322 source.push(patch.tap_model(model, input)?);
323 }
324 let output = self.wire(&node.name, &mut patch, &source)?;
325 patch.shunt_outside(model, node.id.into(), output[0])?;
326 Ok(Some(patch))
327 }
328
329 as_op!();
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates a 1-D StridedSlice over `input`. A `None` start/end is masked out, and a
    /// `None` stride leaves the strides input absent (implicitly 1).
    fn apply(
        input: &[i32],
        start: Option<isize>,
        end: Option<isize>,
        stride: Option<isize>,
    ) -> TValue {
        let op = StridedSlice {
            optional_axes_input: None,
            optional_steps_input: stride.map(|_| 3),
            begin_mask: i64::from(start.is_none()),
            end_mask: i64::from(end.is_none()),
            shrink_axis_mask: 0,
        };
        let begin_value = start.unwrap_or(0) as i32;
        let end_value = end.unwrap_or(0) as i32;
        let mut inputs = tvec!(
            tensor1(input).into(),
            tensor1(&[begin_value]).into(),
            tensor1(&[end_value]).into(),
        );
        if let Some(step) = stride {
            inputs.push(tensor1(&[step as i32]).into());
        }
        let mut outputs = op.eval(inputs).unwrap();
        outputs.remove(0)
    }

    #[test]
    fn numpy_pos_stride() {
        let got = apply(&[0, 1, 2, 3], None, None, Some(2));
        assert_eq!(got, tensor1(&[0, 2]).into());
    }

    #[test]
    fn numpy_neg_stride() {
        let got = apply(&[0, 1, 2, 3], None, None, Some(-2));
        assert_eq!(got, tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_even() {
        let got = apply(&[0, 1, 2, 3], Some(-1), None, Some(-2));
        assert_eq!(got, tensor1(&[3, 1]).into());
    }

    #[test]
    fn numpy_neg_stride_with_start_odd() {
        let got = apply(&[0, 1, 2, 3, 4], Some(-1), None, Some(-2));
        assert_eq!(got, tensor1(&[4, 2, 0]).into());
    }
}