// tract_core/ops/cnn/deconv/deconv.rs

1use crate::internal::*;
2use crate::ops::array::MultiBroadcastTo;
3use crate::ops::cnn::wire_reshape_bias_for_bin;
4use crate::ops::cnn::KernelFormat;
5use crate::ops::cnn::PoolSpec;
6use crate::ops::einsum::EinSum;
7
/// Transposed convolution (deconvolution): the adjoint of convolution,
/// lowered at codegen time to an EinSum followed by a scatter-accumulate
/// (`DeconvSum`) — see `wire_with_deconv_sum`.
#[derive(Clone, Debug, new, Hash)]
pub struct Deconv {
    // Geometry: data format, kernel shape, strides, dilations, padding,
    // input/output channel counts.
    pub pool_spec: PoolSpec,
    // Storage layout of the kernel tensor (e.g. OIHW vs HWIO).
    pub kernel_format: KernelFormat,
    // Per-spatial-axis extra size fed into the output-shape computation
    // (presumably ONNX-style "output_padding" — see super::output_shape).
    pub adjustments: TVec<usize>,
    // Number of convolution groups; 1 means a plain (ungrouped) deconv.
    pub group: usize,
}
15
impl Deconv {
    /// Lowers the deconvolution into primitive ops wired into `target`:
    ///
    /// 1. collapse the input's spatial axes into one axis (and split the
    ///    channel axis by group if needed),
    /// 2. normalize the kernel to a `(G) (O*HW) I/G` matrix,
    /// 3. multiply them with an EinSum,
    /// 4. scatter-accumulate the product into the output with `DeconvSum`,
    ///    adding in the broadcast bias.
    ///
    /// `inputs` holds the outlets for `[input, kernel, bias]`; returns the
    /// outlet of the final `DeconvSum` node.
    fn wire_with_deconv_sum(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
        let shape = self.pool_spec.data_format.shape(input_shape.to_tvec())?;
        // Total number of spatial positions in the input (product of H, W, ...).
        let geo_dim = shape.hw_dims().iter().product();

        // collapse H and W together in input: (N) I HW or (N) HW I
        let mut input = target.wire_node(
            format!("{name}.reshaped_input"),
            AxisOp::Reshape(shape.h_axis(), shape.hw_dims().into(), tvec!(geo_dim)),
            &[inputs[0]],
        )?;

        // rework input to (N) (G) I/G HW or (N) (G) HW I/G
        if self.group != 1 {
            // input is (N) HW I or (N) I HW
            let i_axis = self.pool_spec.data_format.has_n() as usize
                + self.pool_spec.data_format.c_is_last() as usize;
            let i_dim = target.outlet_fact(input[0])?.shape[i_axis].clone();
            // Split the channel axis I into (G, I/G).
            input = target.wire_node(
                format!("{name}.reshaped_input_for_group"),
                AxisOp::Reshape(
                    i_axis,
                    tvec![i_dim.clone()],
                    tvec!(self.group.to_dim(), i_dim / self.group),
                ),
                &input,
            )?;
            if self.pool_spec.data_format.c_is_last() {
                // Channel-last layout is now (N) HW G I/G: move the G axis
                // left, next to N, to get (N) G HW I/G.
                input = target.wire_node(
                    format!("{name}.group_axis_left"),
                    AxisOp::Move(
                        self.pool_spec.data_format.has_n() as usize + 1,
                        self.pool_spec.data_format.has_n() as usize,
                    ),
                    &input,
                )?;
            }
        }

        // Normalize the kernel, whatever its storage format, to G O I/G HW.
        let mut kernel = tvec!(inputs[1]);
        let kernel_fact = target.outlet_fact(kernel[0])?.clone();
        for (ix, op) in self
            .kernel_format
            .kernel_as_group_o_i_hw_ops(&kernel_fact.shape, self.group)
            .into_iter()
            .enumerate()
        {
            kernel = target.wire_node(format!("{name}.kernel.{ix}"), op, &kernel)?;
        }

        // G O I/G HW -> G O HW I/G, then collapse O and HW into a single axis.
        kernel = target.wire_node(format!("{name}.kernel.mv_i"), AxisOp::Move(2, 3), &kernel)?;
        kernel =
            AxisOp::wire_collapse_axis(target, format!("{name}.kernel.col_ohw"), kernel[0], 1)?;
        if self.group == 1 {
            kernel = target.wire_node(format!("{name}.kernel.rm_g"), AxisOp::Rm(0), &kernel)?;
        }
        // Einsum indices: g = group, m = collapsed O*kernel_HW, k = I/G
        // (contracted), n = collapsed input spatial positions, N = batch.
        // The channel-last variant contracts over the trailing axis of the
        // input instead of the one before n.
        let mut expr = if self.pool_spec.data_format.c_is_last() {
            "gmk,Ngnk->Ngmn".to_string()
        } else {
            "gmk,Ngkn->Ngmn".to_string()
        };
        // Drop the batch and/or group indices when absent.
        if !self.pool_spec.data_format.has_n() {
            expr = expr.replace('N', "");
        }
        if self.group == 1 {
            expr = expr.replace('g', "");
        }
        let einsum = target.wire_node(
            format!("{name}.einsum"),
            EinSum { axes: expr.parse()?, operating_dt: kernel_fact.datum_type, q_params: None },
            &[kernel[0], input[0]],
        )?;

        // Reshape the bias to broadcast along the output channel axis, then
        // materialize it at the full output shape: DeconvSum takes it as a
        // second, already-broadcast input.
        let mut bias = wire_reshape_bias_for_bin(
            target,
            format!("{name}.reshape_bias"),
            inputs[2],
            shape.rank(),
            shape.c_axis(),
            self.pool_spec.output_channels,
        )?[0];
        let output_shape = super::output_shape(&self.pool_spec, &shape.shape, &self.adjustments)?;
        bias = target.wire_node(
            format!("{name}.broadcast_bias"),
            MultiBroadcastTo { shape: output_shape.into() },
            &[bias],
        )?[0];

        // einsum must be (N_)CHkWk_HW
        let deconv_sum = target.wire_node(
            format!("{name}.deconv_sum"),
            super::deconv_sum::DeconvSum::new(
                self.pool_spec.clone(),
                self.kernel_format,
                input_shape,
                self.adjustments.clone(),
                self.group,
            ),
            &[einsum[0], bias],
        )?;
        Ok(deconv_sum)
    }
}
125
126impl Op for Deconv {
127    fn name(&self) -> Cow<str> {
128        "Deconv".into()
129    }
130
131    fn info(&self) -> TractResult<Vec<String>> {
132        Ok(vec![format!("{:?}", self.pool_spec)])
133    }
134
135    op_as_typed_op!();
136}
137
138impl EvalOp for Deconv {
139    fn is_stateless(&self) -> bool {
140        true
141    }
142
143    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
144        ensure!(inputs.len() == 3);
145        let mut model = TypedModel::default();
146        let inputs = inputs
147            .into_iter()
148            .enumerate()
149            .map(|(ix, input)| model.add_const(format!("s{ix}"), input.into_tensor()))
150            .collect::<TractResult<TVec<OutletId>>>()?;
151        let output = self.wire_with_deconv_sum("adhoc", &mut model, &inputs)?;
152        model.set_output_outlets(&output)?;
153        model.into_runnable()?.run(tvec![]).context("In adhoc deconvolution eval")
154    }
155}
156
impl TypedOp for Deconv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Inputs are [input, kernel, bias].
        ensure!(inputs.len() == 3);
        let x_fact = inputs[0];
        let k_fact = inputs[1];
        // Declared input channel count must agree with the data tensor...
        ensure!(
            &self.pool_spec.input_channels.to_dim()
                == self.pool_spec.data_format.shape(&inputs[0].shape)?.c()
        );
        // ...and with the kernel tensor (accounting for grouping).
        ensure!(
            self.pool_spec.input_channels.to_dim()
                == *self.kernel_format.input_channels(&k_fact.shape, self.group)
        );
        // Same datum type as the input, at the deconvolved output shape.
        let output_shape = super::output_shape(&self.pool_spec, &x_fact.shape, &self.adjustments)?;
        Ok(tvec!(x_fact.datum_type.fact(&output_shape)))
    }

    /// Maps input-0 axes to output axes: the batch axis always carries
    /// through; a spatial axis is linked only when the deconv acts as an
    /// identity along it (kernel dim 1, stride 1, valid padding, zero
    /// adjustment). Input/output channel axes get distinct labels ('I'/'O')
    /// since their sizes may differ.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let k_fact = &inputs[1];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        // One label per spatial axis: H, W, X, Y, Z, then a, b, c, ...
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = self.kernel_format.spatial_shape(&k_fact.shape);
        for ((ix, dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim.is_one()
                && self.pool_spec.stride(ix) == 1
                && self.pool_spec.padding.valid_dim(ix, true)
                && self.adjustments[ix] == 0
            {
                // Along this axis the op is a no-op: the input and output
                // axes coincide and can share a label.
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking((InOut::In(0), ix + h_axis), (InOut::Out(0), ix + h_axis))?;
            }
        }
        Ok(axes)
    }

    /// Replaces this node by the expanded subgraph produced by
    /// `wire_with_deconv_sum`, patched in place of the original output.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        let inputs = patch.taps(model, &node.inputs)?;
        let output = self
            .wire_with_deconv_sum(&node.name, &mut patch, &inputs)
            .context("In wire_with_deconv_sum")?;
        patch.shunt_outside(model, node.id.into(), output[0])?;
        Ok(Some(patch))
    }

    as_op!();
}