Skip to main content

tract_core/ops/array/
dyn_slice.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5    pub axis: usize,
6    pub len: TDim,
7}
8
9impl DynSlice {
10    pub fn suffix(&self) -> String {
11        format!("axis{}", self.axis)
12    }
13}
14
15impl Op for DynSlice {
16    fn name(&self) -> StaticName {
17        "DynSlice".into()
18    }
19
20    fn info(&self) -> TractResult<Vec<String>> {
21        Ok(vec![format!("axis: {}", self.axis)])
22    }
23
24    op_as_typed_op!();
25
26    fn same_as(&self, other: &dyn Op) -> bool {
27        if let Some(other) = other.downcast_ref::<Self>() { other == self } else { false }
28    }
29}
30
31impl EvalOp for DynSlice {
32    fn is_stateless(&self) -> bool {
33        true
34    }
35
36    fn eval_with_session(
37        &self,
38        _node_id: usize,
39        session: &TurnState,
40        inputs: TVec<TValue>,
41    ) -> TractResult<TVec<TValue>> {
42        let start = inputs[1]
43            .cast_to::<TDim>()?
44            .to_scalar::<TDim>()?
45            .eval(&session.resolved_symbols)
46            .to_usize()?;
47        let end = inputs[2]
48            .cast_to::<TDim>()?
49            .to_scalar::<TDim>()?
50            .eval(&session.resolved_symbols)
51            .to_usize()?;
52        ensure!(start <= end);
53        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
54            ensure!(start + len == end);
55        }
56        let slice = inputs[0].slice(self.axis, start, end)?;
57        Ok(tvec!(slice.into()))
58    }
59}
60
61impl TypedOp for DynSlice {
62    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
63        ensure!(inputs.len() == 3);
64        let mut fact = inputs[0].without_value();
65        fact.shape.set(self.axis, self.len.clone());
66        Ok(tvec!(fact))
67    }
68
69    fn axes_mapping(
70        &self,
71        inputs: &[&TypedFact],
72        _outputs: &[&TypedFact],
73    ) -> TractResult<AxesMapping> {
74        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
75            .with_extra_input(1)?
76            .with_extra_input(2)
77    }
78
79    fn change_axes(
80        &self,
81        model: &TypedModel,
82        node: &TypedNode,
83        io: InOut,
84        change: &AxisOp,
85    ) -> TractResult<Option<AxisChangeConsequence>> {
86        rule_if!(io != InOut::In(1) && io != InOut::In(2));
87        rule_if_some!(axis = change.transform_axis(self.axis));
88        if axis != self.axis {
89            Ok(Some(AxisChangeConsequence::new(
90                model,
91                node,
92                Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
93                change,
94            )))
95        } else {
96            Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
97        }
98    }
99
100    fn declutter(
101        &self,
102        model: &TypedModel,
103        node: &TypedNode,
104    ) -> TractResult<Option<TypedModelPatch>> {
105        let inputs = model.node_input_facts(node.id)?;
106        rule_if_some!(start = &inputs[1].konst);
107        rule_if_some!(end = &inputs[2].konst);
108        let start = start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
109        let end = end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
110
111        Ok(Some(TypedModelPatch::replace_single_op(
112            model,
113            node,
114            &[node.inputs[0]],
115            crate::ops::array::Slice { axis: self.axis, start, end },
116        )?))
117    }
118
119    as_op!();
120}