Skip to main content

tract_core/ops/array/
dyn_slice.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5    pub axis: usize,
6    pub len: TDim,
7}
8
9impl DynSlice {
10    pub fn suffix(&self) -> String {
11        format!("axis{}", self.axis)
12    }
13}
14
15impl Op for DynSlice {
16    fn name(&self) -> StaticName {
17        "DynSlice".into()
18    }
19
20    fn info(&self) -> TractResult<Vec<String>> {
21        Ok(vec![format!("axis: {}", self.axis)])
22    }
23
24    op_as_typed_op!();
25
26    fn same_as(&self, other: &dyn Op) -> bool {
27        if let Some(other) = other.downcast_ref::<Self>() { other == self } else { false }
28    }
29}
30
31impl EvalOp for DynSlice {
32    fn is_stateless(&self) -> bool {
33        true
34    }
35
36    fn eval_with_session(
37        &self,
38        _node_id: usize,
39        session: &TurnState,
40        inputs: TVec<TValue>,
41    ) -> TractResult<TVec<TValue>> {
42        let start = inputs[1]
43            .cast_to::<TDim>()?
44            .try_as_dense()?
45            .to_scalar::<TDim>()?
46            .eval(&session.resolved_symbols)
47            .to_usize()?;
48        let end = inputs[2]
49            .cast_to::<TDim>()?
50            .try_as_dense()?
51            .to_scalar::<TDim>()?
52            .eval(&session.resolved_symbols)
53            .to_usize()?;
54        ensure!(start <= end);
55        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
56            ensure!(start + len == end);
57        }
58        let slice = inputs[0].slice(self.axis, start, end)?;
59        Ok(tvec!(slice.into()))
60    }
61}
62
63impl TypedOp for DynSlice {
64    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
65        ensure!(inputs.len() == 3);
66        let mut fact = inputs[0].without_value();
67        fact.shape.set(self.axis, self.len.clone());
68        Ok(tvec!(fact))
69    }
70
71    fn axes_mapping(
72        &self,
73        inputs: &[&TypedFact],
74        _outputs: &[&TypedFact],
75    ) -> TractResult<AxesMapping> {
76        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
77            .with_extra_input(1)?
78            .with_extra_input(2)
79    }
80
81    fn change_axes(
82        &self,
83        model: &TypedModel,
84        node: &TypedNode,
85        io: InOut,
86        change: &AxisOp,
87    ) -> TractResult<Option<AxisChangeConsequence>> {
88        rule_if!(io != InOut::In(1) && io != InOut::In(2));
89        rule_if_some!(axis = change.transform_axis(self.axis));
90        if axis != self.axis {
91            Ok(Some(AxisChangeConsequence::new(
92                model,
93                node,
94                Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
95                change,
96            )))
97        } else {
98            Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
99        }
100    }
101
102    fn declutter(
103        &self,
104        model: &TypedModel,
105        node: &TypedNode,
106    ) -> TractResult<Option<TypedModelPatch>> {
107        let inputs = model.node_input_facts(node.id)?;
108        rule_if_some!(start = &inputs[1].konst);
109        rule_if_some!(end = &inputs[2].konst);
110        let start = start.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone();
111        let end = end.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone();
112
113        Ok(Some(TypedModelPatch::replace_single_op(
114            model,
115            node,
116            &[node.inputs[0]],
117            crate::ops::array::Slice { axis: self.axis, start, end },
118        )?))
119    }
120
121    as_op!();
122}