tract_core/ops/array/
dyn_slice.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5    pub axis: usize,
6    pub len: TDim,
7}
8
9impl DynSlice {
10    pub fn suffix(&self) -> String {
11        format!("axis{}", self.axis)
12    }
13}
14
15impl Op for DynSlice {
16    fn name(&self) -> StaticName {
17        "DynSlice".into()
18    }
19
20    fn info(&self) -> TractResult<Vec<String>> {
21        Ok(vec![format!("axis: {}", self.axis)])
22    }
23
24    op_as_typed_op!();
25
26    fn same_as(&self, other: &dyn Op) -> bool {
27        if let Some(other) = other.downcast_ref::<Self>() {
28            other == self
29        } else {
30            false
31        }
32    }
33}
34
35impl EvalOp for DynSlice {
36    fn is_stateless(&self) -> bool {
37        true
38    }
39
40    fn eval_with_session(
41        &self,
42        _node_id: usize,
43        session: &SessionState,
44        inputs: TVec<TValue>,
45    ) -> TractResult<TVec<TValue>> {
46        let start = inputs[1]
47            .cast_to::<TDim>()?
48            .to_scalar::<TDim>()?
49            .eval(&session.resolved_symbols)
50            .to_usize()?;
51        let end = inputs[2]
52            .cast_to::<TDim>()?
53            .to_scalar::<TDim>()?
54            .eval(&session.resolved_symbols)
55            .to_usize()?;
56        ensure!(start <= end);
57        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
58            ensure!(start + len == end);
59        }
60        let slice = inputs[0].slice(self.axis, start, end)?;
61        Ok(tvec!(slice.into()))
62    }
63}
64
65impl TypedOp for DynSlice {
66    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67        ensure!(inputs.len() == 3);
68        let mut fact = inputs[0].without_value();
69        fact.shape.set(self.axis, self.len.clone());
70        Ok(tvec!(fact))
71    }
72
73    fn axes_mapping(
74        &self,
75        inputs: &[&TypedFact],
76        _outputs: &[&TypedFact],
77    ) -> TractResult<AxesMapping> {
78        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
79            .with_extra_input(1)?
80            .with_extra_input(2)
81    }
82
83    fn change_axes(
84        &self,
85        model: &TypedModel,
86        node: &TypedNode,
87        io: InOut,
88        change: &AxisOp,
89    ) -> TractResult<Option<AxisChangeConsequence>> {
90        if io == InOut::In(1) || io == InOut::In(2) {
91            return Ok(None);
92        }
93        if let Some(axis) = change.transform_axis(self.axis) {
94            if axis != self.axis {
95                Ok(Some(AxisChangeConsequence::new(
96                    model,
97                    node,
98                    Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
99                    change,
100                )))
101            } else {
102                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
103            }
104        } else {
105            Ok(None)
106        }
107    }
108
109    fn declutter(
110        &self,
111        model: &TypedModel,
112        node: &TypedNode,
113    ) -> TractResult<Option<TypedModelPatch>> {
114        let inputs = model.node_input_facts(node.id)?;
115        if let (Some(start), Some(end)) = (&inputs[1].konst, &inputs[2].konst) {
116            let start = start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
117            let end = end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
118
119            return Ok(Some(TypedModelPatch::replace_single_op(
120                model,
121                node,
122                &[node.inputs[0]],
123                crate::ops::array::Slice { axis: self.axis, start, end },
124            )?));
125        }
126        Ok(None)
127    }
128
129    as_op!();
130}