1use crate::internal::*;
2use crate::num_traits::Zero;
3
4#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
5pub struct Slice {
6 pub axis: usize,
7 pub start: TDim,
8 pub end: TDim,
9}
10
11impl Slice {
12 pub fn new(axis: usize, start: impl ToDim, end: impl ToDim) -> Slice {
13 Slice { axis, start: start.to_dim(), end: end.to_dim() }
14 }
15
16 pub fn suffix(&self, name: &str) -> String {
17 format!("{}.axis{}_{}_{}", name, self.axis, self.start, self.end)
18 }
19
20 pub fn declutter_slice_after_slice(
21 &self,
22 model: &TypedModel,
23 node: &TypedNode,
24 ) -> TractResult<Option<TypedModelPatch>> {
25 let prec = model.node(node.inputs[0].node);
26 if let Some(other) = prec.op_as::<Slice>() {
27 if other.axis == self.axis {
28 return TypedModelPatch::replace_single_op(
29 model,
30 node,
31 &prec.inputs,
32 Slice {
33 axis: self.axis,
34 start: self.start.clone() + &other.start,
35 end: self.end.clone() + &other.start,
36 },
37 )
38 .map(Some);
39 }
40 }
41 Ok(None)
42 }
43}
44
45impl Op for Slice {
46 fn name(&self) -> StaticName {
47 "Slice".into()
48 }
49
50 fn info(&self) -> TractResult<Vec<String>> {
51 Ok(vec![format!("axis: {}, {}..{}", self.axis, self.start, self.end)])
52 }
53
54 op_as_typed_op!();
55
56 fn same_as(&self, other: &dyn Op) -> bool {
57 if let Some(other) = other.downcast_ref::<Self>() {
58 other == self
59 } else {
60 false
61 }
62 }
63}
64
65impl EvalOp for Slice {
66 fn is_stateless(&self) -> bool {
67 true
68 }
69
70 fn eval_with_session(
71 &self,
72 _node_id: usize,
73 session: &SessionState,
74 inputs: TVec<TValue>,
75 ) -> TractResult<TVec<TValue>> {
76 let input = args_1!(inputs);
77 let start = self.start.eval(&session.resolved_symbols).to_usize()?;
78 let end = self.end.eval(&session.resolved_symbols).to_usize()?;
79 eval_slice(&input, self.axis, start, end)
80 }
81}
82
83fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractResult<TVec<TValue>> {
84 if end > input.shape()[axis] || start > end {
85 bail!("Invalid range {}..{} for slicing {:?} on axis {}", start, end, input, axis);
86 }
87 unsafe {
88 let mut shape: TVec<_> = input.shape().into();
89 shape[axis] = end - start;
90 let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
91 tensor.assign_slice_unchecked(.., input, start..end, axis);
92 Ok(tvec!(tensor.into_tvalue()))
93 }
94}
95
96impl TypedOp for Slice {
97 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
98 anyhow::ensure!(inputs.len() == 1, "Slice has one single input");
99 if let (Ok(start), Ok(end), Ok(len)) =
100 (self.start.to_usize(), self.end.to_usize(), inputs[0].shape[self.axis].to_usize())
101 {
102 ensure!(start <= end);
103 ensure!(end <= len);
104 }
105 let mut fact = inputs[0].without_value();
106 fact.shape.set(self.axis, (self.end.clone() - &self.start).to_dim());
107 Ok(tvec!(fact))
108 }
109
110 fn axes_mapping(
111 &self,
112 inputs: &[&TypedFact],
113 outputs: &[&TypedFact],
114 ) -> TractResult<AxesMapping> {
115 let mut mapping = AxesMapping::disconnected(inputs, outputs)?;
116 for (axis, repr) in (0..inputs[0].rank()).zip('a'..) {
117 if self.axis != axis {
118 mapping = mapping
119 .renaming((InOut::In(0), axis), repr)?
120 .linking(repr, (InOut::Out(0), axis))?;
121 }
122 }
123 Ok(mapping)
124 }
125
126 fn change_axes(
127 &self,
128 model: &TypedModel,
129 node: &TypedNode,
130 _io: InOut,
131 change: &AxisOp,
132 ) -> TractResult<Option<AxisChangeConsequence>> {
133 if let Some(axis) = change.transform_axis(self.axis) {
134 if axis != self.axis {
135 Ok(Some(AxisChangeConsequence::new(
136 model,
137 node,
138 Some(Box::new(Slice { axis, ..self.clone() }) as _),
139 change,
140 )))
141 } else {
142 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
143 }
144 } else {
145 Ok(None)
146 }
147 }
148
149 fn declutter(
150 &self,
151 model: &TypedModel,
152 node: &TypedNode,
153 ) -> TractResult<Option<TypedModelPatch>> {
154 if self.start.is_zero() && (self.end == model.outlet_fact(node.inputs[0])?.shape[self.axis])
155 {
156 TypedModelPatch::shunt_one_op(model, node)
157 } else if let Some(p) = self.declutter_slice_after_slice(model, node)? {
158 Ok(Some(p))
159 } else {
160 Ok(None)
161 }
162 }
163
164 fn concretize_dims(
165 &self,
166 _source: &TypedModel,
167 node: &TypedNode,
168 target: &mut TypedModel,
169 mapping: &HashMap<OutletId, OutletId>,
170 values: &SymbolValues,
171 ) -> TractResult<TVec<OutletId>> {
172 let op =
173 Slice { axis: self.axis, start: self.start.eval(values), end: self.end.eval(values) };
174 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
175 target.wire_node(&node.name, op, &inputs)
176 }
177
178 fn slice(
179 &self,
180 patch: &mut TypedModelPatch,
181 _model: &TypedModel,
182 node: &TypedNode,
183 _prefix: &str,
184 inputs: &[OutletId],
185 _output_axis: usize,
186 _start: &TDim,
187 _end: &TDim,
188 ) -> TractResult<Option<TVec<OutletId>>> {
189 patch.wire_node(&node.name, &node.op, inputs).map(Some)
190 }
191
192 as_op!();
193}