tract_core/ops/array/
dyn_slice.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5 pub axis: usize,
6 pub len: TDim,
7}
8
9impl DynSlice {
10 pub fn suffix(&self) -> String {
11 format!("axis{}", self.axis)
12 }
13}
14
15impl Op for DynSlice {
16 fn name(&self) -> StaticName {
17 "DynSlice".into()
18 }
19
20 fn info(&self) -> TractResult<Vec<String>> {
21 Ok(vec![format!("axis: {}", self.axis)])
22 }
23
24 op_as_typed_op!();
25
26 fn same_as(&self, other: &dyn Op) -> bool {
27 if let Some(other) = other.downcast_ref::<Self>() { other == self } else { false }
28 }
29}
30
31impl EvalOp for DynSlice {
32 fn is_stateless(&self) -> bool {
33 true
34 }
35
36 fn eval_with_session(
37 &self,
38 _node_id: usize,
39 session: &TurnState,
40 inputs: TVec<TValue>,
41 ) -> TractResult<TVec<TValue>> {
42 let start = inputs[1]
43 .cast_to::<TDim>()?
44 .try_as_dense()?
45 .to_scalar::<TDim>()?
46 .eval(&session.resolved_symbols)
47 .to_usize()?;
48 let end = inputs[2]
49 .cast_to::<TDim>()?
50 .try_as_dense()?
51 .to_scalar::<TDim>()?
52 .eval(&session.resolved_symbols)
53 .to_usize()?;
54 ensure!(start <= end);
55 if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
56 ensure!(start + len == end);
57 }
58 let slice = inputs[0].slice(self.axis, start, end)?;
59 Ok(tvec!(slice.into()))
60 }
61}
62
63impl TypedOp for DynSlice {
64 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
65 ensure!(inputs.len() == 3);
66 let mut fact = inputs[0].without_value();
67 fact.shape.set(self.axis, self.len.clone());
68 Ok(tvec!(fact))
69 }
70
71 fn axes_mapping(
72 &self,
73 inputs: &[&TypedFact],
74 _outputs: &[&TypedFact],
75 ) -> TractResult<AxesMapping> {
76 AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
77 .with_extra_input(1)?
78 .with_extra_input(2)
79 }
80
81 fn change_axes(
82 &self,
83 model: &TypedModel,
84 node: &TypedNode,
85 io: InOut,
86 change: &AxisOp,
87 ) -> TractResult<Option<AxisChangeConsequence>> {
88 rule_if!(io != InOut::In(1) && io != InOut::In(2));
89 rule_if_some!(axis = change.transform_axis(self.axis));
90 if axis != self.axis {
91 Ok(Some(AxisChangeConsequence::new(
92 model,
93 node,
94 Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
95 change,
96 )))
97 } else {
98 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
99 }
100 }
101
102 fn declutter(
103 &self,
104 model: &TypedModel,
105 node: &TypedNode,
106 ) -> TractResult<Option<TypedModelPatch>> {
107 let inputs = model.node_input_facts(node.id)?;
108 rule_if_some!(start = &inputs[1].konst);
109 rule_if_some!(end = &inputs[2].konst);
110 let start = start.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone();
111 let end = end.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone();
112
113 Ok(Some(TypedModelPatch::replace_single_op(
114 model,
115 node,
116 &[node.inputs[0]],
117 crate::ops::array::Slice { axis: self.axis, start, end },
118 )?))
119 }
120
121 as_op!();
122}