Skip to main content

tract_core/ops/array/slice.rs

1use crate::internal::*;
2use crate::num_traits::Zero;
3
4#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
5pub struct Slice {
6    pub axis: usize,
7    pub start: TDim,
8    pub end: TDim,
9}
10
11impl Slice {
12    pub fn new(axis: usize, start: impl ToDim, end: impl ToDim) -> Slice {
13        Slice { axis, start: start.to_dim(), end: end.to_dim() }
14    }
15
16    pub fn suffix(&self, name: &str) -> String {
17        format!("{}.axis{}_{}_{}", name, self.axis, self.start, self.end)
18    }
19
20    pub fn declutter_slice_after_slice(
21        &self,
22        model: &TypedModel,
23        node: &TypedNode,
24    ) -> TractResult<Option<TypedModelPatch>> {
25        let prec = model.node(node.inputs[0].node);
26        if let Some(other) = prec.op_as::<Slice>()
27            && other.axis == self.axis
28        {
29            return TypedModelPatch::replace_single_op(
30                model,
31                node,
32                &prec.inputs,
33                Slice {
34                    axis: self.axis,
35                    start: self.start.clone() + &other.start,
36                    end: self.end.clone() + &other.start,
37                },
38            )
39            .map(Some);
40        }
41        Ok(None)
42    }
43}
44
45impl Op for Slice {
46    fn name(&self) -> StaticName {
47        "Slice".into()
48    }
49
50    fn info(&self) -> TractResult<Vec<String>> {
51        Ok(vec![format!("axis: {}, {}..{}", self.axis, self.start, self.end)])
52    }
53
54    op_as_typed_op!();
55}
56
57impl EvalOp for Slice {
58    fn is_stateless(&self) -> bool {
59        true
60    }
61
62    fn eval_with_session(
63        &self,
64        _node_id: usize,
65        session: &TurnState,
66        inputs: TVec<TValue>,
67    ) -> TractResult<TVec<TValue>> {
68        let input = args_1!(inputs);
69        let start = self.start.eval(&session.resolved_symbols).to_usize()?;
70        let end = self.end.eval(&session.resolved_symbols).to_usize()?;
71        eval_slice(&input, self.axis, start, end)
72    }
73}
74
75fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractResult<TVec<TValue>> {
76    if end > input.shape()[axis] || start > end {
77        bail!("Invalid range {}..{} for slicing {:?} on axis {}", start, end, input, axis);
78    }
79    unsafe {
80        let mut shape: TVec<_> = input.shape().into();
81        shape[axis] = end - start;
82        let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
83        tensor.assign_slice_unchecked(.., input, start..end, axis);
84        Ok(tvec!(tensor.into_tvalue()))
85    }
86}
87
88impl TypedOp for Slice {
89    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
90        anyhow::ensure!(inputs.len() == 1, "Slice has one single input");
91        if let (Ok(start), Ok(end), Ok(len)) =
92            (self.start.to_usize(), self.end.to_usize(), inputs[0].shape[self.axis].to_usize())
93        {
94            ensure!(start <= end);
95            ensure!(end <= len);
96        }
97        let mut fact = inputs[0].without_value();
98        fact.shape.set(self.axis, (self.end.clone() - &self.start).to_dim());
99        Ok(tvec!(fact))
100    }
101
102    fn input_roi(
103        &self,
104        model: &TypedModel,
105        node: &TypedNode,
106    ) -> TractResult<Option<TVec<Option<TDim>>>> {
107        let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
108        let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
109        if self.start.is_zero() {
110            return Ok(Some(tvec![Some(roi.clone())]));
111        }
112        // Remap: output 🎯axis = input 🎯axis - start, so substitute 🎯axis → 🎯axis + start
113        if let Some(sym) = roi
114            .symbols()
115            .into_iter()
116            .find(|s| crate::ops::logic::sym_to_coord_axis(s) == Some(self.axis))
117        {
118            let shifted = TDim::Sym(sym.clone()) + self.start.clone();
119            if let Ok(input_roi) = roi.substitute(&sym, &shifted) {
120                return Ok(Some(tvec![Some(input_roi)]));
121            }
122        }
123        // ROI doesn't mention the sliced axis — pass through unchanged
124        Ok(Some(tvec![Some(roi.clone())]))
125    }
126
127    fn axes_mapping(
128        &self,
129        inputs: &[&TypedFact],
130        outputs: &[&TypedFact],
131    ) -> TractResult<AxesMapping> {
132        let mut mapping = AxesMapping::disconnected(inputs, outputs)?;
133        for (axis, repr) in (0..inputs[0].rank()).zip('a'..) {
134            if self.axis != axis {
135                mapping = mapping
136                    .renaming((InOut::In(0), axis), repr)?
137                    .linking(repr, (InOut::Out(0), axis))?;
138            }
139        }
140        Ok(mapping)
141    }
142
143    fn change_axes(
144        &self,
145        model: &TypedModel,
146        node: &TypedNode,
147        _io: InOut,
148        change: &AxisOp,
149    ) -> TractResult<Option<AxisChangeConsequence>> {
150        if let Some(axis) = change.transform_axis(self.axis) {
151            if axis != self.axis {
152                Ok(Some(AxisChangeConsequence::new(
153                    model,
154                    node,
155                    Some(Box::new(Slice { axis, ..self.clone() }) as _),
156                    change,
157                )))
158            } else {
159                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
160            }
161        } else {
162            Ok(None)
163        }
164    }
165
166    fn declutter(
167        &self,
168        model: &TypedModel,
169        node: &TypedNode,
170    ) -> TractResult<Option<TypedModelPatch>> {
171        if self.start.is_zero() && (self.end == model.outlet_fact(node.inputs[0])?.shape[self.axis])
172        {
173            TypedModelPatch::shunt_one_op(model, node)
174        } else if let Some(p) = self.declutter_slice_after_slice(model, node)? {
175            Ok(Some(p))
176        } else {
177            Ok(None)
178        }
179    }
180
181    fn concretize_dims(
182        &self,
183        _source: &TypedModel,
184        node: &TypedNode,
185        target: &mut TypedModel,
186        mapping: &HashMap<OutletId, OutletId>,
187        values: &SymbolValues,
188    ) -> TractResult<TVec<OutletId>> {
189        let op =
190            Slice { axis: self.axis, start: self.start.eval(values), end: self.end.eval(values) };
191        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
192        target.wire_node(&node.name, op, &inputs)
193    }
194
195    fn slice(
196        &self,
197        patch: &mut TypedModelPatch,
198        _model: &TypedModel,
199        node: &TypedNode,
200        _prefix: &str,
201        inputs: &[OutletId],
202        _output_axis: usize,
203        _start: &TDim,
204        _end: &TDim,
205    ) -> TractResult<Option<TVec<OutletId>>> {
206        patch.wire_node(&node.name, &node.op, inputs).map(Some)
207    }
208
209    as_op!();
210}