tract_core/ops/cnn/deconv/deconv.rs

use crate::internal::*;
use crate::ops::array::MultiBroadcastTo;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::KernelFormat;
use crate::ops::cnn::PoolSpec;
use crate::ops::einsum::EinSum;

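/// Deconvolution (transposed convolution).
///
/// At codegen time the op is rewritten into an explicit `EinSum` over the kernel and the
/// flattened input geometry, followed by the companion `DeconvSum` op (see
/// `wire_with_deconv_sum` below).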
#[derive(Clone, Debug, new, Hash)]
pub struct Deconv {
    pub pool_spec: PoolSpec,
    pub kernel_format: KernelFormat,
    pub adjustments: TVec<usize>,
    pub group: usize,
}

impl Deconv {
    fn wire_with_deconv_sum(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
        let shape = self.pool_spec.data_format.shape(input_shape.to_tvec())?;
        let geo_dim = shape.hw_dims().iter().product();

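        // Collapse all spatial axes of the input into a single geometry axis, so that the
        // kernel application below becomes one matrix product per input position.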
        let mut input = target.wire_node(
            format!("{name}.reshaped_input"),
            AxisOp::Reshape(shape.h_axis(), shape.hw_dims().into(), tvec!(geo_dim)),
            &[inputs[0]],
        )?;

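        // For grouped deconvolution, split the channel axis into (group, channels per group);
        // with a channels-last layout, the new group axis is then moved to the front (just
        // after N when present) so it lines up with the kernel's leading group axis.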
        if self.group != 1 {
            let i_axis = self.pool_spec.data_format.has_n() as usize
                + self.pool_spec.data_format.c_is_last() as usize;
            let i_dim = target.outlet_fact(input[0])?.shape[i_axis].clone();
            input = target.wire_node(
                format!("{name}.reshaped_input_for_group"),
                AxisOp::Reshape(
                    i_axis,
                    tvec![i_dim.clone()],
                    tvec!(self.group.to_dim(), i_dim / self.group),
                ),
                &input,
            )?;
            if self.pool_spec.data_format.c_is_last() {
                input = target.wire_node(
                    format!("{name}.group_axis_left"),
                    AxisOp::Move(
                        self.pool_spec.data_format.has_n() as usize + 1,
                        self.pool_spec.data_format.has_n() as usize,
                    ),
                    &input,
                )?;
            }
        }

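        // Bring the kernel into [group, O, I, HW] order, move the input-channel axis last and
        // collapse output channels with kernel positions: the kernel ends up as g,m,k with
        // m = (per-group) output channels x kernel positions and k = (per-group) input channels.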
        let mut kernel = tvec!(inputs[1]);
        let kernel_fact = target.outlet_fact(kernel[0])?.clone();
        for (ix, op) in self
            .kernel_format
            .kernel_as_group_o_i_hw_ops(&kernel_fact.shape, self.group)
            .into_iter()
            .enumerate()
        {
            kernel = target.wire_node(format!("{name}.kernel.{ix}"), op, &kernel)?;
        }

        kernel = target.wire_node(format!("{name}.kernel.mv_i"), AxisOp::Move(2, 3), &kernel)?;
        kernel =
            AxisOp::wire_collapse_axis(target, format!("{name}.kernel.col_ohw"), kernel[0], 1)?;
        if self.group == 1 {
            kernel = target.wire_node(format!("{name}.kernel.rm_g"), AxisOp::Rm(0), &kernel)?;
        }
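        // Einsum letters: g = group, N = batch, m = output channels x kernel positions,
        // k = input channels, n = collapsed input geometry. N and g are stripped from the
        // expression below when there is no batch axis or only a single group.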
        let mut expr = if self.pool_spec.data_format.c_is_last() {
            "gmk,Ngnk->Ngmn".to_string()
        } else {
            "gmk,Ngkn->Ngmn".to_string()
        };
        if !self.pool_spec.data_format.has_n() {
            expr = expr.replace('N', "");
        }
        if self.group == 1 {
            expr = expr.replace('g', "");
        }
        let einsum = target.wire_node(
            format!("{name}.einsum"),
            EinSum { axes: expr.parse()?, operating_dt: kernel_fact.datum_type, q_params: None },
            &[kernel[0], input[0]],
        )?;

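        // Line the bias up with the output channel axis and broadcast it to the full output
        // shape; it is passed to DeconvSum below as its second input.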
        let mut bias = wire_reshape_bias_for_bin(
            target,
            format!("{name}.reshape_bias"),
            inputs[2],
            shape.rank(),
            shape.c_axis(),
            self.pool_spec.output_channels,
        )?[0];
        let output_shape = super::output_shape(&self.pool_spec, &shape.shape, &self.adjustments)?;
        bias = target.wire_node(
            format!("{name}.broadcast_bias"),
            MultiBroadcastTo { shape: output_shape.into() },
            &[bias],
        )?[0];

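        // DeconvSum is expected to scatter each (output channel x kernel position, input
        // position) contribution of the einsum onto its strided/padded output location,
        // accumulating on top of the broadcast bias.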
        let deconv_sum = target.wire_node(
            format!("{name}.deconv_sum"),
            super::deconv_sum::DeconvSum::new(
                self.pool_spec.clone(),
                self.kernel_format,
                input_shape,
                self.adjustments.clone(),
                self.group,
            ),
            &[einsum[0], bias],
        )?;
        Ok(deconv_sum)
    }
}

impl Op for Deconv {
    fn name(&self) -> Cow<str> {
        "Deconv".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("{:?}", self.pool_spec)])
    }

    op_as_typed_op!();
}

impl EvalOp for Deconv {
    fn is_stateless(&self) -> bool {
        true
    }

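    // Stateless eval: wire the same einsum + DeconvSum decomposition into a throwaway
    // TypedModel, with the inputs turned into constants, and run it once.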
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(inputs.len() == 3);
        let mut model = TypedModel::default();
        let inputs = inputs
            .into_iter()
            .enumerate()
            .map(|(ix, input)| model.add_const(format!("s{ix}"), input.into_tensor()))
            .collect::<TractResult<TVec<OutletId>>>()?;
        let output = self.wire_with_deconv_sum("adhoc", &mut model, &inputs)?;
        model.set_output_outlets(&output)?;
        model.into_runnable()?.run(tvec![]).context("In adhoc deconvolution eval")
    }
}

impl TypedOp for Deconv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 3);
        let x_fact = inputs[0];
        let k_fact = inputs[1];
        ensure!(
            &self.pool_spec.input_channels.to_dim()
                == self.pool_spec.data_format.shape(&inputs[0].shape)?.c()
        );
        ensure!(
            self.pool_spec.input_channels.to_dim()
                == *self.kernel_format.input_channels(&k_fact.shape, self.group)
        );
        let output_shape = super::output_shape(&self.pool_spec, &x_fact.shape, &self.adjustments)?;
        Ok(tvec!(x_fact.datum_type.fact(&output_shape)))
    }

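    // Axis tracking: input and output channel axes get independent labels ('I'/'O'), the batch
    // axis maps straight through, and a spatial axis is linked input-to-output only when the
    // op leaves it untouched (kernel dim 1, stride 1, valid padding, no adjustment).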
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let k_fact = &inputs[1];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = self.kernel_format.spatial_shape(&k_fact.shape);
        for ((ix, dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim.is_one()
                && self.pool_spec.stride(ix) == 1
                && self.pool_spec.padding.valid_dim(ix, true)
                && self.adjustments[ix] == 0
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking((InOut::In(0), ix + h_axis), (InOut::Out(0), ix + h_axis))?;
            }
        }
        Ok(axes)
    }

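    // Codegen replaces the Deconv node with the explicit einsum + DeconvSum subgraph built by
    // wire_with_deconv_sum, shunted in place of the original node.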
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        let inputs = patch.taps(model, &node.inputs)?;
        let output = self
            .wire_with_deconv_sum(&node.name, &mut patch, &inputs)
            .context("In wire_with_deconv_sum")?;
        patch.shunt_outside(model, node.id.into(), output[0])?;
        Ok(Some(patch))
    }

    as_op!();
}