tract_core/ops/array/
slice.rs

1use crate::internal::*;
2use crate::num_traits::Zero;
3
4#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
5pub struct Slice {
6    pub axis: usize,
7    pub start: TDim,
8    pub end: TDim,
9}
10
11impl Slice {
12    pub fn new(axis: usize, start: impl ToDim, end: impl ToDim) -> Slice {
13        Slice { axis, start: start.to_dim(), end: end.to_dim() }
14    }
15
16    pub fn suffix(&self, name: &str) -> String {
17        format!("{}.axis{}_{}_{}", name, self.axis, self.start, self.end)
18    }
19
20    pub fn declutter_slice_after_slice(
21        &self,
22        model: &TypedModel,
23        node: &TypedNode,
24    ) -> TractResult<Option<TypedModelPatch>> {
25        let prec = model.node(node.inputs[0].node);
26        if let Some(other) = prec.op_as::<Slice>() {
27            if other.axis == self.axis {
28                return TypedModelPatch::replace_single_op(
29                    model,
30                    node,
31                    &prec.inputs,
32                    Slice {
33                        axis: self.axis,
34                        start: self.start.clone() + &other.start,
35                        end: self.end.clone() + &other.start,
36                    },
37                )
38                .map(Some);
39            }
40        }
41        Ok(None)
42    }
43}
44
45impl Op for Slice {
46    fn name(&self) -> StaticName {
47        "Slice".into()
48    }
49
50    fn info(&self) -> TractResult<Vec<String>> {
51        Ok(vec![format!("axis: {}, {}..{}", self.axis, self.start, self.end)])
52    }
53
54    op_as_typed_op!();
55
56    fn same_as(&self, other: &dyn Op) -> bool {
57        if let Some(other) = other.downcast_ref::<Self>() {
58            other == self
59        } else {
60            false
61        }
62    }
63}
64
65impl EvalOp for Slice {
66    fn is_stateless(&self) -> bool {
67        true
68    }
69
70    fn eval_with_session(
71        &self,
72        _node_id: usize,
73        session: &SessionState,
74        inputs: TVec<TValue>,
75    ) -> TractResult<TVec<TValue>> {
76        let input = args_1!(inputs);
77        let start = self.start.eval(&session.resolved_symbols).to_usize()?;
78        let end = self.end.eval(&session.resolved_symbols).to_usize()?;
79        eval_slice(&input, self.axis, start, end)
80    }
81}
82
83fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractResult<TVec<TValue>> {
84    if end > input.shape()[axis] || start > end {
85        bail!("Invalid range {}..{} for slicing {:?} on axis {}", start, end, input, axis);
86    }
87    unsafe {
88        let mut shape: TVec<_> = input.shape().into();
89        shape[axis] = end - start;
90        let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
91        tensor.assign_slice_unchecked(.., input, start..end, axis);
92        Ok(tvec!(tensor.into_tvalue()))
93    }
94}
95
96impl TypedOp for Slice {
97    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
98        anyhow::ensure!(inputs.len() == 1, "Slice has one single input");
99        if let (Ok(start), Ok(end), Ok(len)) =
100            (self.start.to_usize(), self.end.to_usize(), inputs[0].shape[self.axis].to_usize())
101        {
102            ensure!(start <= end);
103            ensure!(end <= len);
104        }
105        let mut fact = inputs[0].without_value();
106        fact.shape.set(self.axis, (self.end.clone() - &self.start).to_dim());
107        Ok(tvec!(fact))
108    }
109
110    fn axes_mapping(
111        &self,
112        inputs: &[&TypedFact],
113        outputs: &[&TypedFact],
114    ) -> TractResult<AxesMapping> {
115        let mut mapping = AxesMapping::disconnected(inputs, outputs)?;
116        for (axis, repr) in (0..inputs[0].rank()).zip('a'..) {
117            if self.axis != axis {
118                mapping = mapping
119                    .renaming((InOut::In(0), axis), repr)?
120                    .linking(repr, (InOut::Out(0), axis))?;
121            }
122        }
123        Ok(mapping)
124    }
125
126    fn change_axes(
127        &self,
128        model: &TypedModel,
129        node: &TypedNode,
130        _io: InOut,
131        change: &AxisOp,
132    ) -> TractResult<Option<AxisChangeConsequence>> {
133        if let Some(axis) = change.transform_axis(self.axis) {
134            if axis != self.axis {
135                Ok(Some(AxisChangeConsequence::new(
136                    model,
137                    node,
138                    Some(Box::new(Slice { axis, ..self.clone() }) as _),
139                    change,
140                )))
141            } else {
142                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
143            }
144        } else {
145            Ok(None)
146        }
147    }
148
149    fn declutter(
150        &self,
151        model: &TypedModel,
152        node: &TypedNode,
153    ) -> TractResult<Option<TypedModelPatch>> {
154        if self.start.is_zero() && (self.end == model.outlet_fact(node.inputs[0])?.shape[self.axis])
155        {
156            TypedModelPatch::shunt_one_op(model, node)
157        } else if let Some(p) = self.declutter_slice_after_slice(model, node)? {
158            Ok(Some(p))
159        } else {
160            Ok(None)
161        }
162    }
163
164    fn concretize_dims(
165        &self,
166        _source: &TypedModel,
167        node: &TypedNode,
168        target: &mut TypedModel,
169        mapping: &HashMap<OutletId, OutletId>,
170        values: &SymbolValues,
171    ) -> TractResult<TVec<OutletId>> {
172        let op =
173            Slice { axis: self.axis, start: self.start.eval(values), end: self.end.eval(values) };
174        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
175        target.wire_node(&node.name, op, &inputs)
176    }
177
178    fn slice(
179        &self,
180        patch: &mut TypedModelPatch,
181        _model: &TypedModel,
182        node: &TypedNode,
183        _prefix: &str,
184        inputs: &[OutletId],
185        _output_axis: usize,
186        _start: &TDim,
187        _end: &TDim,
188    ) -> TractResult<Option<TVec<OutletId>>> {
189        patch.wire_node(&node.name, &node.op, inputs).map(Some)
190    }
191
192    as_op!();
193}