tract_core/ops/cnn/deconv/
mod.rs1use crate::internal::*;
2use crate::ops::cnn::{PaddingSpec, PoolSpec};
3
4#[allow(clippy::module_inception)]
5mod deconv;
6mod deconv_sum;
7
8pub use deconv::Deconv;
9
10pub fn output_shape<D: DimLike>(
11 pool_spec: &PoolSpec,
12 x_shape: &[D],
13 adjustments: &[usize],
14) -> TractResult<TVec<D>> {
15 let x_shape = pool_spec.data_format.shape(x_shape)?;
16 let spatial_input_shape = x_shape.hw_dims();
17 let spatial_output_details = pool_spec.padding.compute_for_deconv(
18 spatial_input_shape,
19 &pool_spec.kernel_shape,
20 &pool_spec.dilations(),
21 &pool_spec.strides(),
22 adjustments,
23 )?;
24 let deconv_shape: TVec<D> =
25 spatial_output_details.iter().map(|comp| comp.deconvoluted.clone()).collect();
26 let output_shape = pool_spec.data_format.from_n_c_hw(
27 x_shape.n().cloned().unwrap_or_else(|| 1.into()),
28 pool_spec.output_channels.into(),
29 deconv_shape,
30 )?;
31 Ok(output_shape.shape)
32}
33
34pub fn adjustments(
35 pool_spec: &PoolSpec,
36 input_geo: &[usize],
37 output_geo: &[usize],
38) -> TractResult<TVec<usize>> {
39 debug_assert_eq!(pool_spec.rank(), pool_spec.strides().len());
40 debug_assert_eq!(pool_spec.rank(), pool_spec.dilations().len());
41 debug_assert_eq!(pool_spec.rank(), pool_spec.kernel_shape.len());
42 debug_assert_eq!(pool_spec.rank(), input_geo.len());
43 debug_assert_eq!(pool_spec.rank(), output_geo.len());
44 let rank = pool_spec.rank();
45 let pad: TVec<usize> = match &pool_spec.padding {
46 PaddingSpec::Explicit(beg, end) => (0..rank).map(|r| beg[r] + end[r]).collect(),
47 PaddingSpec::Valid => tvec!(0; rank),
48 pad => todo!("Unsupported padding in deconvolution arguments {pad:?}"),
49 };
50 tract_itertools::izip!(
51 input_geo,
52 &pool_spec.kernel_shape,
53 output_geo,
54 pool_spec.strides().as_ref(),
55 pool_spec.dilations().as_ref(),
56 pad,
57 )
58 .map(|(x, k, y, s, d, p)| {
59 let adj = y.to_usize()? + p - s * (x.to_usize()? - 1) - (k.to_usize()? - 1) * d - 1;
60 Ok(adj)
61 })
62 .collect::<TractResult<TVec<usize>>>()
63}