tract_core/ops/array/
dyn_slice.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5 pub axis: usize,
6 pub len: TDim,
7}
8
9impl DynSlice {
10 pub fn suffix(&self) -> String {
11 format!("axis{}", self.axis)
12 }
13}
14
15impl Op for DynSlice {
16 fn name(&self) -> StaticName {
17 "DynSlice".into()
18 }
19
20 fn info(&self) -> TractResult<Vec<String>> {
21 Ok(vec![format!("axis: {}", self.axis)])
22 }
23
24 op_as_typed_op!();
25}
26
27impl EvalOp for DynSlice {
28 fn is_stateless(&self) -> bool {
29 true
30 }
31
32 fn eval_with_session(
33 &self,
34 _node_id: usize,
35 session: &TurnState,
36 inputs: TVec<TValue>,
37 ) -> TractResult<TVec<TValue>> {
38 let start = inputs[1]
39 .cast_to::<TDim>()?
40 .try_as_plain()?
41 .to_scalar::<TDim>()?
42 .eval(&session.resolved_symbols)
43 .to_usize()?;
44 let end = inputs[2]
45 .cast_to::<TDim>()?
46 .try_as_plain()?
47 .to_scalar::<TDim>()?
48 .eval(&session.resolved_symbols)
49 .to_usize()?;
50 ensure!(start <= end);
51 if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
52 ensure!(start + len == end);
53 }
54 let slice = inputs[0].slice(self.axis, start, end)?;
55 Ok(tvec!(slice.into()))
56 }
57}
58
59impl TypedOp for DynSlice {
60 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
61 ensure!(inputs.len() == 3);
62 let mut fact = inputs[0].without_value();
63 fact.shape.set(self.axis, self.len.clone());
64 Ok(tvec!(fact))
65 }
66
67 fn axes_mapping(
68 &self,
69 inputs: &[&TypedFact],
70 _outputs: &[&TypedFact],
71 ) -> TractResult<AxesMapping> {
72 AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
73 .with_extra_input(1)?
74 .with_extra_input(2)
75 }
76
77 fn change_axes(
78 &self,
79 model: &TypedModel,
80 node: &TypedNode,
81 io: InOut,
82 change: &AxisOp,
83 ) -> TractResult<Option<AxisChangeConsequence>> {
84 rule_if!(io != InOut::In(1) && io != InOut::In(2));
85 rule_if_some!(axis = change.transform_axis(self.axis));
86 if axis != self.axis {
87 Ok(Some(AxisChangeConsequence::new(
88 model,
89 node,
90 Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
91 change,
92 )))
93 } else {
94 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
95 }
96 }
97
98 fn declutter(
99 &self,
100 model: &TypedModel,
101 node: &TypedNode,
102 ) -> TractResult<Option<TypedModelPatch>> {
103 let inputs = model.node_input_facts(node.id)?;
104 rule_if_some!(start = &inputs[1].konst);
105 rule_if_some!(end = &inputs[2].konst);
106 let start = start.cast_to::<TDim>()?.try_as_plain()?.to_scalar::<TDim>()?.clone();
107 let end = end.cast_to::<TDim>()?.try_as_plain()?.to_scalar::<TDim>()?.clone();
108
109 Ok(Some(TypedModelPatch::replace_single_op(
110 model,
111 node,
112 &[node.inputs[0]],
113 crate::ops::array::Slice { axis: self.axis, start, end },
114 )?))
115 }
116
117 as_op!();
118}