Skip to main content

tract_core/ops/array/
dyn_slice.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5    pub axis: usize,
6    pub len: TDim,
7}
8
9impl DynSlice {
10    pub fn suffix(&self) -> String {
11        format!("axis{}", self.axis)
12    }
13}
14
15impl Op for DynSlice {
16    fn name(&self) -> StaticName {
17        "DynSlice".into()
18    }
19
20    fn info(&self) -> TractResult<Vec<String>> {
21        Ok(vec![format!("axis: {}", self.axis)])
22    }
23
24    op_as_typed_op!();
25}
26
27impl EvalOp for DynSlice {
28    fn is_stateless(&self) -> bool {
29        true
30    }
31
32    fn eval_with_session(
33        &self,
34        _node_id: usize,
35        session: &TurnState,
36        inputs: TVec<TValue>,
37    ) -> TractResult<TVec<TValue>> {
38        let start = inputs[1]
39            .cast_to::<TDim>()?
40            .try_as_plain()?
41            .to_scalar::<TDim>()?
42            .eval(&session.resolved_symbols)
43            .to_usize()?;
44        let end = inputs[2]
45            .cast_to::<TDim>()?
46            .try_as_plain()?
47            .to_scalar::<TDim>()?
48            .eval(&session.resolved_symbols)
49            .to_usize()?;
50        ensure!(start <= end);
51        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
52            ensure!(start + len == end);
53        }
54        let slice = inputs[0].slice(self.axis, start, end)?;
55        Ok(tvec!(slice.into()))
56    }
57}
58
59impl TypedOp for DynSlice {
60    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
61        ensure!(inputs.len() == 3);
62        let mut fact = inputs[0].without_value();
63        fact.shape.set(self.axis, self.len.clone());
64        Ok(tvec!(fact))
65    }
66
67    fn axes_mapping(
68        &self,
69        inputs: &[&TypedFact],
70        _outputs: &[&TypedFact],
71    ) -> TractResult<AxesMapping> {
72        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
73            .with_extra_input(1)?
74            .with_extra_input(2)
75    }
76
77    fn change_axes(
78        &self,
79        model: &TypedModel,
80        node: &TypedNode,
81        io: InOut,
82        change: &AxisOp,
83    ) -> TractResult<Option<AxisChangeConsequence>> {
84        rule_if!(io != InOut::In(1) && io != InOut::In(2));
85        rule_if_some!(axis = change.transform_axis(self.axis));
86        if axis != self.axis {
87            Ok(Some(AxisChangeConsequence::new(
88                model,
89                node,
90                Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
91                change,
92            )))
93        } else {
94            Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
95        }
96    }
97
98    fn declutter(
99        &self,
100        model: &TypedModel,
101        node: &TypedNode,
102    ) -> TractResult<Option<TypedModelPatch>> {
103        let inputs = model.node_input_facts(node.id)?;
104        rule_if_some!(start = &inputs[1].konst);
105        rule_if_some!(end = &inputs[2].konst);
106        let start = start.cast_to::<TDim>()?.try_as_plain()?.to_scalar::<TDim>()?.clone();
107        let end = end.cast_to::<TDim>()?.try_as_plain()?.to_scalar::<TDim>()?.clone();
108
109        Ok(Some(TypedModelPatch::replace_single_op(
110            model,
111            node,
112            &[node.inputs[0]],
113            crate::ops::array::Slice { axis: self.axis, start, end },
114        )?))
115    }
116
117    as_op!();
118}