1use crate::internal::*;
2use crate::num_traits::Zero;
3
/// Extracts the contiguous range `start..end` along a single axis of its input.
///
/// Bounds are symbolic (`TDim`); they are resolved against the session's symbol
/// values when the op is evaluated.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Slice {
    /// Axis along which the input is sliced.
    pub axis: usize,
    /// First index (inclusive) kept along `axis`.
    pub start: TDim,
    /// End of the kept range along `axis` (exclusive).
    pub end: TDim,
}
10
11impl Slice {
12 pub fn new(axis: usize, start: impl ToDim, end: impl ToDim) -> Slice {
13 Slice { axis, start: start.to_dim(), end: end.to_dim() }
14 }
15
16 pub fn suffix(&self, name: &str) -> String {
17 format!("{}.axis{}_{}_{}", name, self.axis, self.start, self.end)
18 }
19
20 pub fn declutter_slice_after_slice(
21 &self,
22 model: &TypedModel,
23 node: &TypedNode,
24 ) -> TractResult<Option<TypedModelPatch>> {
25 let prec = model.node(node.inputs[0].node);
26 if let Some(other) = prec.op_as::<Slice>()
27 && other.axis == self.axis
28 {
29 return TypedModelPatch::replace_single_op(
30 model,
31 node,
32 &prec.inputs,
33 Slice {
34 axis: self.axis,
35 start: self.start.clone() + &other.start,
36 end: self.end.clone() + &other.start,
37 },
38 )
39 .map(Some);
40 }
41 Ok(None)
42 }
43}
44
45impl Op for Slice {
46 fn name(&self) -> StaticName {
47 "Slice".into()
48 }
49
50 fn info(&self) -> TractResult<Vec<String>> {
51 Ok(vec![format!("axis: {}, {}..{}", self.axis, self.start, self.end)])
52 }
53
54 op_as_typed_op!();
55}
56
57impl EvalOp for Slice {
58 fn is_stateless(&self) -> bool {
59 true
60 }
61
62 fn eval_with_session(
63 &self,
64 _node_id: usize,
65 session: &TurnState,
66 inputs: TVec<TValue>,
67 ) -> TractResult<TVec<TValue>> {
68 let input = args_1!(inputs);
69 let start = self.start.eval(&session.resolved_symbols).to_usize()?;
70 let end = self.end.eval(&session.resolved_symbols).to_usize()?;
71 eval_slice(&input, self.axis, start, end)
72 }
73}
74
75fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractResult<TVec<TValue>> {
76 if end > input.shape()[axis] || start > end {
77 bail!("Invalid range {}..{} for slicing {:?} on axis {}", start, end, input, axis);
78 }
79 unsafe {
80 let mut shape: TVec<_> = input.shape().into();
81 shape[axis] = end - start;
82 let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
83 tensor.assign_slice_unchecked(.., input, start..end, axis);
84 Ok(tvec!(tensor.into_tvalue()))
85 }
86}
87
88impl TypedOp for Slice {
89 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
90 anyhow::ensure!(inputs.len() == 1, "Slice has one single input");
91 if let (Ok(start), Ok(end), Ok(len)) =
92 (self.start.to_usize(), self.end.to_usize(), inputs[0].shape[self.axis].to_usize())
93 {
94 ensure!(start <= end);
95 ensure!(end <= len);
96 }
97 let mut fact = inputs[0].without_value();
98 fact.shape.set(self.axis, (self.end.clone() - &self.start).to_dim());
99 Ok(tvec!(fact))
100 }
101
102 fn input_roi(
103 &self,
104 model: &TypedModel,
105 node: &TypedNode,
106 ) -> TractResult<Option<TVec<Option<TDim>>>> {
107 let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
108 let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
109 if self.start.is_zero() {
110 return Ok(Some(tvec![Some(roi.clone())]));
111 }
112 if let Some(sym) = roi
114 .symbols()
115 .into_iter()
116 .find(|s| crate::ops::logic::sym_to_coord_axis(s) == Some(self.axis))
117 {
118 let shifted = TDim::Sym(sym.clone()) + self.start.clone();
119 if let Ok(input_roi) = roi.substitute(&sym, &shifted) {
120 return Ok(Some(tvec![Some(input_roi)]));
121 }
122 }
123 Ok(Some(tvec![Some(roi.clone())]))
125 }
126
127 fn axes_mapping(
128 &self,
129 inputs: &[&TypedFact],
130 outputs: &[&TypedFact],
131 ) -> TractResult<AxesMapping> {
132 let mut mapping = AxesMapping::disconnected(inputs, outputs)?;
133 for (axis, repr) in (0..inputs[0].rank()).zip('a'..) {
134 if self.axis != axis {
135 mapping = mapping
136 .renaming((InOut::In(0), axis), repr)?
137 .linking(repr, (InOut::Out(0), axis))?;
138 }
139 }
140 Ok(mapping)
141 }
142
143 fn change_axes(
144 &self,
145 model: &TypedModel,
146 node: &TypedNode,
147 _io: InOut,
148 change: &AxisOp,
149 ) -> TractResult<Option<AxisChangeConsequence>> {
150 if let Some(axis) = change.transform_axis(self.axis) {
151 if axis != self.axis {
152 Ok(Some(AxisChangeConsequence::new(
153 model,
154 node,
155 Some(Box::new(Slice { axis, ..self.clone() }) as _),
156 change,
157 )))
158 } else {
159 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
160 }
161 } else {
162 Ok(None)
163 }
164 }
165
166 fn declutter(
167 &self,
168 model: &TypedModel,
169 node: &TypedNode,
170 ) -> TractResult<Option<TypedModelPatch>> {
171 if self.start.is_zero() && (self.end == model.outlet_fact(node.inputs[0])?.shape[self.axis])
172 {
173 TypedModelPatch::shunt_one_op(model, node)
174 } else if let Some(p) = self.declutter_slice_after_slice(model, node)? {
175 Ok(Some(p))
176 } else {
177 Ok(None)
178 }
179 }
180
181 fn concretize_dims(
182 &self,
183 _source: &TypedModel,
184 node: &TypedNode,
185 target: &mut TypedModel,
186 mapping: &HashMap<OutletId, OutletId>,
187 values: &SymbolValues,
188 ) -> TractResult<TVec<OutletId>> {
189 let op =
190 Slice { axis: self.axis, start: self.start.eval(values), end: self.end.eval(values) };
191 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
192 target.wire_node(&node.name, op, &inputs)
193 }
194
195 fn slice(
196 &self,
197 patch: &mut TypedModelPatch,
198 _model: &TypedModel,
199 node: &TypedNode,
200 _prefix: &str,
201 inputs: &[OutletId],
202 _output_axis: usize,
203 _start: &TDim,
204 _end: &TDim,
205 ) -> TractResult<Option<TVec<OutletId>>> {
206 patch.wire_node(&node.name, &node.op, inputs).map(Some)
207 }
208
209 as_op!();
210}