Skip to main content

tract_core/ops/downsample/mod.rs

1use crate::internal::*;
2use crate::ops;
3use ndarray::prelude::*;
4
5use super::identity::Identity;
6
7mod array;
8mod conv;
9mod scan;
10
11#[derive(Debug, Clone, new, Default, PartialEq, Eq, Hash)]
12pub struct Downsample {
13    pub axis: usize,
14    pub stride: isize,
15    pub modulo: usize,
16}
17
18impl Downsample {
19    pub(crate) fn transform_dim(&self, input_dim: &TDim) -> TDim {
20        (input_dim.clone() - self.modulo).div_ceil(self.stride.unsigned_abs() as u64)
21    }
22
23    pub(crate) fn transform_fact(&self, input_fact: &TypedFact) -> TractResult<TypedFact> {
24        let mut downed = input_fact.clone();
25        let down_len = self.transform_dim(&input_fact.shape[self.axis]);
26        downed.shape.set(self.axis, down_len);
27        if let Some(k) = downed.konst {
28            let mut outputs = self.eval(tvec!(k.into_tvalue()))?;
29            downed.konst = Some(outputs.remove(0).into_arc_tensor())
30        }
31        if cfg!(debug_assertions) {
32            downed.consistent()?;
33        }
34        Ok(downed)
35    }
36}
37
38impl Op for Downsample {
39    fn name(&self) -> StaticName {
40        "Downsample".into()
41    }
42
43    fn info(&self) -> TractResult<Vec<String>> {
44        Ok(vec![format!("axis:{} stride:{} modulo:{}", self.axis, self.stride, self.modulo)])
45    }
46
47    op_as_typed_op!();
48}
49
50impl EvalOp for Downsample {
51    fn is_stateless(&self) -> bool {
52        true
53    }
54
55    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
56        let input = args_1!(inputs);
57        unsafe {
58            let t = if self.modulo > input.shape()[self.axis] {
59                let mut shape: TVec<usize> = input.shape().into();
60                shape[self.axis] = 0;
61                Tensor::uninitialized_dt(input.datum_type(), &shape)?
62            } else {
63                let slice = ndarray::Slice::new(self.modulo as isize, None, self.stride);
64                unsafe fn do_slice<T: Datum>(
65                    t: &Tensor,
66                    axis: usize,
67                    slice: ndarray::Slice,
68                ) -> Tensor {
69                    unsafe {
70                        let dt = t.datum_type();
71                        let mut t2 = t
72                            .to_array_view_unchecked::<T>()
73                            .slice_axis(Axis(axis), slice)
74                            .into_owned()
75                            .into_tensor();
76                        t2.set_datum_type(dt);
77                        t2
78                    }
79                }
80                dispatch_datum_by_size!(do_slice(input.datum_type())(&*input, self.axis, slice))
81            };
82            Ok(tvec!(t.into_tvalue()))
83        }
84    }
85}
86
87impl TypedOp for Downsample {
88    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
89        ensure!(self.axis < inputs[0].rank());
90        ensure!(
91            self.modulo == 0 || self.stride > 0,
92            "non-zero modulo is only defined with forward strides"
93        );
94        let mut downed = inputs[0].without_value();
95        let down_len = self.transform_dim(&downed.shape[self.axis]);
96        downed.shape.set(self.axis, down_len);
97        Ok(tvec!(downed))
98    }
99
100    fn declutter(
101        &self,
102        model: &TypedModel,
103        node: &TypedNode,
104    ) -> TractResult<Option<TypedModelPatch>> {
105        if self.stride == 1 {
106            return Ok(Some(TypedModelPatch::replace_single_op(
107                model,
108                node,
109                &node.inputs,
110                Identity,
111            )?));
112        }
113        pull_downsample_up(model, node)
114            .with_context(|| format!("Pulling {} over {}", node, model.node(node.inputs[0].node)))
115    }
116
117    as_op!();
118}
119
120fn pull_downsample_up(
121    model: &TypedModel,
122    down_node: &TypedNode,
123) -> TractResult<Option<TypedModelPatch>> {
124    model.check_consistency()?;
125    let down_op = down_node.op_as::<Downsample>().unwrap();
126    if let Some(prec) = model.linear_prec(down_node.id)? {
127        let (input_facts, output_facts) = model.node_facts(prec.id)?;
128        let axes_mapping = prec.op.axes_mapping(&input_facts, &output_facts)?;
129        debug!("Consider pull {down_op:?} over {prec:?} (invariants: {axes_mapping:?})");
130        if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
131            if let Some(p) =
132                array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)?
133            {
134                return Ok(Some(p));
135            }
136        } else if let Some(other_op) = prec.op_as::<AxisOp>() {
137            return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
138        } else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::Conv>() {
139            return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
140        } else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
141            return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
142        }
143        rule_if!(prec.outputs.len() <= 1 && prec.inputs.len() > 0);
144        let axis_info = axes_mapping.axis((InOut::Out(0), down_op.axis))?;
145        let mut patch = TypedModelPatch::default();
146        let mut inputs = vec![];
147        for (ix, (outlet, axis_info)) in prec.inputs.iter().zip(&axis_info.inputs).enumerate() {
148            let mut wire = patch.tap_model(model, *outlet)?;
149            if let &[axis] = &**axis_info {
150                if !patch.outlet_fact(wire)?.shape[axis].is_one() {
151                    let mut op = down_op.clone();
152                    op.axis = axis;
153                    wire = patch.wire_node(
154                        format!("{}.{}-{}", down_node.name, prec.name, ix),
155                        op,
156                        &[wire],
157                    )?[0];
158                }
159            } else {
160                return Ok(None);
161            }
162            inputs.push(wire);
163        }
164        let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
165        patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
166        return Ok(Some(patch));
167    }
168    Ok(None)
169}