tract_core/ops/downsample/mod.rs

use crate::internal::*;
use crate::ops;
use ndarray::prelude::*;

use super::identity::Identity;

mod array;
mod conv;
mod scan;

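/// Subsamples its input along `axis`, keeping one element every `stride`,
/// starting at offset `modulo`. Negative strides traverse the axis backwards
/// and are only valid with `modulo == 0` (see `output_facts`).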
#[derive(Debug, Clone, new, Default, PartialEq, Eq, Hash)]
pub struct Downsample {
    pub axis: usize,
    pub stride: isize,
    pub modulo: usize,
}

impl Downsample {
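    /// Length of the downsampled axis: ceil((input_dim - modulo) / |stride|).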
    pub(crate) fn transform_dim(&self, input_dim: &TDim) -> TDim {
        (input_dim.clone() - self.modulo).div_ceil(self.stride.unsigned_abs() as u64)
    }

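    /// Derives the output fact: shrinks the shape along `self.axis` and, when the
    /// input fact carries a constant value, downsamples that constant as well.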
    pub(crate) fn transform_fact(&self, input_fact: &TypedFact) -> TractResult<TypedFact> {
        let mut downed = input_fact.clone();
        let down_len = self.transform_dim(&input_fact.shape[self.axis]);
        downed.shape.set(self.axis, down_len);
        if let Some(k) = downed.konst {
            let mut outputs = self.eval(tvec!(k.into_tvalue()))?;
            downed.konst = Some(outputs.remove(0).into_arc_tensor())
        }
        if cfg!(debug_assertions) {
            downed.consistent()?;
        }
        Ok(downed)
    }
}

impl Op for Downsample {
    fn name(&self) -> StaticName {
        "Downsample".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("axis:{} stride:{} modulo:{}", self.axis, self.stride, self.modulo)])
    }

    impl_op_same_as!();
    op_as_typed_op!();
}

impl EvalOp for Downsample {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        unsafe {
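            // If the start offset lies past the end of the axis, the result is
            // empty: build a tensor with extent 0 on that axis instead of slicing.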
            let t = if self.modulo > input.shape()[self.axis] {
                let mut shape: TVec<usize> = input.shape().into();
                shape[self.axis] = 0;
                Tensor::uninitialized_dt(input.datum_type(), &shape)?
            } else {
                let slice = ndarray::Slice::new(self.modulo as isize, None, self.stride);
                unsafe fn do_slice<T: Datum>(
                    t: &Tensor,
                    axis: usize,
                    slice: ndarray::Slice,
                ) -> Tensor {
                    unsafe {
                        let dt = t.datum_type();
                        let mut t2 = t
                            .to_array_view_unchecked::<T>()
                            .slice_axis(Axis(axis), slice)
                            .into_owned()
                            .into_tensor();
                        t2.set_datum_type(dt);
                        t2
                    }
                }
                dispatch_datum_by_size!(do_slice(input.datum_type())(&*input, self.axis, slice))
            };
            Ok(tvec!(t.into_tvalue()))
        }
    }
}

impl TypedOp for Downsample {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.axis < inputs[0].rank());
        ensure!(
            self.modulo == 0 || self.stride > 0,
            "non-zero modulo is only defined with forward strides"
        );
        let mut downed = inputs[0].without_value();
        let down_len = self.transform_dim(&downed.shape[self.axis]);
        downed.shape.set(self.axis, down_len);
        Ok(tvec!(downed))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
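        // A Downsample with stride 1 is replaced by Identity; any other stride is
        // pulled up through its predecessor node when possible.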
        if self.stride == 1 {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &node.inputs,
                Identity,
            )?));
        }
        pull_downsample_up(model, node)
            .with_context(|| format!("Pulling {} over {}", node, model.node(node.inputs[0].node)))
    }

    as_op!();
}

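/// Tries to move a Downsample above its single predecessor. Slice, AxisOp, Conv
/// and Scan get dedicated rewrites in the `array`, `conv` and `scan` submodules;
/// for other single-output nodes the Downsample is re-created on each input,
/// along the input axis that tracks the downsampled output axis.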
fn pull_downsample_up(
    model: &TypedModel,
    down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    model.check_consistency()?;
    let down_op = down_node.op_as::<Downsample>().unwrap();
    if let Some(prec) = model.linear_prec(down_node.id)? {
        let (input_facts, output_facts) = model.node_facts(prec.id)?;
        let axes_mapping = prec.op.axes_mapping(&input_facts, &output_facts)?;
        debug!("Consider pull {down_op:?} over {prec:?} (invariants: {axes_mapping:?})");
        if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
            if let Some(p) =
                array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)?
            {
                return Ok(Some(p));
            }
        } else if let Some(other_op) = prec.op_as::<AxisOp>() {
            return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
        } else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::Conv>() {
            return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
        } else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
            return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
        }
        if prec.outputs.len() > 1 || prec.inputs.len() == 0 {
            return Ok(None);
        }
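        // Generic case: follow the downsampled output axis back to each input through
        // the axes mapping. Every input must track it through exactly one axis; inputs
        // of extent 1 on that axis are left untouched, the others get their own
        // Downsample, and the predecessor is rebuilt on the rewired inputs.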
        let axis_info = axes_mapping.axis((InOut::Out(0), down_op.axis))?;
        let mut patch = TypedModelPatch::default();
        let mut inputs = vec![];
        for (ix, (outlet, axis_info)) in prec.inputs.iter().zip(&axis_info.inputs).enumerate() {
            let mut wire = patch.tap_model(model, *outlet)?;
            if let &[axis] = &**axis_info {
                if !patch.outlet_fact(wire)?.shape[axis].is_one() {
                    let mut op = down_op.clone();
                    op.axis = axis;
                    wire = patch.wire_node(
                        format!("{}.{}-{}", down_node.name, prec.name, ix),
                        op,
                        &[wire],
                    )?[0];
                }
            } else {
                return Ok(None);
            }
            inputs.push(wire);
        }
        let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
        patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}