tract_core/ops/array/dyn_slice.rs

1use crate::internal::*;
2
/// Slice of the first input along `axis`, whose bounds are not fixed in the op.
///
/// The start and end bounds are read at evaluation time from the second and third
/// inputs. When both bounds are constants, `declutter` replaces this op with a static
/// `Slice`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
pub struct DynSlice {
    /// Axis of the first input along which the slice is taken.
    pub axis: usize,
    /// Declared length of the output along `axis`.
    ///
    /// This is used as the output dimension at typing time. During evaluation, if it
    /// resolves to a concrete value, it is checked against `end - start`.
    pub len: TDim,
}
8
9impl DynSlice {
10    pub fn suffix(&self) -> String {
11        format!("axis{}", self.axis)
12    }
13}
14
15impl Op for DynSlice {
16    fn name(&self) -> Cow<str> {
17        "DynSlice".into()
18    }
19
20    fn info(&self) -> TractResult<Vec<String>> {
21        Ok(vec![format!("axis: {}", self.axis)])
22    }
23
24    op_as_typed_op!();
25
26    fn same_as(&self, other: &dyn Op) -> bool {
27        if let Some(other) = other.downcast_ref::<Self>() {
28            other == self
29        } else {
30            false
31        }
32    }
33}
34
35impl EvalOp for DynSlice {
36    fn is_stateless(&self) -> bool {
37        true
38    }
39
40    fn eval_with_session(
41        &self,
42        session: &SessionState,
43        inputs: TVec<TValue>,
44    ) -> TractResult<TVec<TValue>> {
45        let start = inputs[1]
46            .cast_to::<TDim>()?
47            .to_scalar::<TDim>()?
48            .eval(&session.resolved_symbols)
49            .to_usize()?;
50        let end = inputs[2]
51            .cast_to::<TDim>()?
52            .to_scalar::<TDim>()?
53            .eval(&session.resolved_symbols)
54            .to_usize()?;
55        ensure!(start <= end);
56        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
57            ensure!(start + len == end);
58        }
59        let slice = inputs[0].slice(self.axis, start, end)?;
60        Ok(tvec!(slice.into()))
61    }
62}
63
64impl TypedOp for DynSlice {
65    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
66        ensure!(inputs.len() == 3);
67        let mut fact = inputs[0].without_value();
68        fact.shape.set(self.axis, self.len.clone());
69        Ok(tvec!(fact))
70    }
71
72    fn axes_mapping(
73        &self,
74        inputs: &[&TypedFact],
75        _outputs: &[&TypedFact],
76    ) -> TractResult<AxesMapping> {
77        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
78            .with_extra_input(1)?
79            .with_extra_input(2)
80    }
81
82    fn change_axes(
83        &self,
84        model: &TypedModel,
85        node: &TypedNode,
86        io: InOut,
87        change: &AxisOp,
88    ) -> TractResult<Option<AxisChangeConsequence>> {
89        if io == InOut::In(1) || io == InOut::In(2) {
90            return Ok(None);
91        }
92        if let Some(axis) = change.transform_axis(self.axis) {
93            if axis != self.axis {
94                Ok(Some(AxisChangeConsequence::new(
95                    model,
96                    node,
97                    Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
98                    change,
99                )))
100            } else {
101                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
102            }
103        } else {
104            Ok(None)
105        }
106    }
107
108    fn declutter(
109        &self,
110        model: &TypedModel,
111        node: &TypedNode,
112    ) -> TractResult<Option<TypedModelPatch>> {
113        let inputs = model.node_input_facts(node.id)?;
114        if let (Some(start), Some(end)) = (&inputs[1].konst, &inputs[2].konst) {
115            let start = start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
116            let end = end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
117
118            return Ok(Some(TypedModelPatch::replace_single_op(
119                model,
120                node,
121                &[node.inputs[0]],
122                crate::ops::array::Slice { axis: self.axis, start, end },
123            )?));
124        }
125        Ok(None)
126    }
127
128    as_op!();
129}