tract_core/ops/array/
dyn_slice.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
4pub struct DynSlice {
5 pub axis: usize,
6 pub len: TDim,
7}
8
9impl DynSlice {
10 pub fn suffix(&self) -> String {
11 format!("axis{}", self.axis)
12 }
13}
14
15impl Op for DynSlice {
16 fn name(&self) -> StaticName {
17 "DynSlice".into()
18 }
19
20 fn info(&self) -> TractResult<Vec<String>> {
21 Ok(vec![format!("axis: {}", self.axis)])
22 }
23
24 op_as_typed_op!();
25
26 fn same_as(&self, other: &dyn Op) -> bool {
27 if let Some(other) = other.downcast_ref::<Self>() {
28 other == self
29 } else {
30 false
31 }
32 }
33}
34
35impl EvalOp for DynSlice {
36 fn is_stateless(&self) -> bool {
37 true
38 }
39
40 fn eval_with_session(
41 &self,
42 _node_id: usize,
43 session: &SessionState,
44 inputs: TVec<TValue>,
45 ) -> TractResult<TVec<TValue>> {
46 let start = inputs[1]
47 .cast_to::<TDim>()?
48 .to_scalar::<TDim>()?
49 .eval(&session.resolved_symbols)
50 .to_usize()?;
51 let end = inputs[2]
52 .cast_to::<TDim>()?
53 .to_scalar::<TDim>()?
54 .eval(&session.resolved_symbols)
55 .to_usize()?;
56 ensure!(start <= end);
57 if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
58 ensure!(start + len == end);
59 }
60 let slice = inputs[0].slice(self.axis, start, end)?;
61 Ok(tvec!(slice.into()))
62 }
63}
64
65impl TypedOp for DynSlice {
66 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67 ensure!(inputs.len() == 3);
68 let mut fact = inputs[0].without_value();
69 fact.shape.set(self.axis, self.len.clone());
70 Ok(tvec!(fact))
71 }
72
73 fn axes_mapping(
74 &self,
75 inputs: &[&TypedFact],
76 _outputs: &[&TypedFact],
77 ) -> TractResult<AxesMapping> {
78 AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
79 .with_extra_input(1)?
80 .with_extra_input(2)
81 }
82
83 fn change_axes(
84 &self,
85 model: &TypedModel,
86 node: &TypedNode,
87 io: InOut,
88 change: &AxisOp,
89 ) -> TractResult<Option<AxisChangeConsequence>> {
90 if io == InOut::In(1) || io == InOut::In(2) {
91 return Ok(None);
92 }
93 if let Some(axis) = change.transform_axis(self.axis) {
94 if axis != self.axis {
95 Ok(Some(AxisChangeConsequence::new(
96 model,
97 node,
98 Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
99 change,
100 )))
101 } else {
102 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
103 }
104 } else {
105 Ok(None)
106 }
107 }
108
109 fn declutter(
110 &self,
111 model: &TypedModel,
112 node: &TypedNode,
113 ) -> TractResult<Option<TypedModelPatch>> {
114 let inputs = model.node_input_facts(node.id)?;
115 if let (Some(start), Some(end)) = (&inputs[1].konst, &inputs[2].konst) {
116 let start = start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
117 let end = end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
118
119 return Ok(Some(TypedModelPatch::replace_single_op(
120 model,
121 node,
122 &[node.inputs[0]],
123 crate::ops::array::Slice { axis: self.axis, start, end },
124 )?));
125 }
126 Ok(None)
127 }
128
129 as_op!();
130}