tract_core/ops/cnn/deconv/
mod.rs

1use crate::internal::*;
2use crate::ops::cnn::{PaddingSpec, PoolSpec};
3
4#[allow(clippy::module_inception)]
5mod deconv;
6mod deconv_sum;
7
8pub use deconv::Deconv;
9
10pub fn output_shape<D: DimLike>(
11    pool_spec: &PoolSpec,
12    x_shape: &[D],
13    adjustments: &[usize],
14) -> TractResult<TVec<D>> {
15    let x_shape = pool_spec.data_format.shape(x_shape)?;
16    let spatial_input_shape = x_shape.hw_dims();
17    let spatial_output_details = pool_spec.padding.compute_for_deconv(
18        spatial_input_shape,
19        &pool_spec.kernel_shape,
20        &pool_spec.dilations(),
21        &pool_spec.strides(),
22        adjustments,
23    )?;
24    let deconv_shape: TVec<D> =
25        spatial_output_details.iter().map(|comp| comp.deconvoluted.clone()).collect();
26    let output_shape = pool_spec.data_format.from_n_c_hw(
27        x_shape.n().cloned().unwrap_or_else(|| 1.into()),
28        pool_spec.output_channels.into(),
29        deconv_shape,
30    )?;
31    Ok(output_shape.shape)
32}
33
34pub fn adjustments(
35    pool_spec: &PoolSpec,
36    input_geo: &[usize],
37    output_geo: &[usize],
38) -> TractResult<TVec<usize>> {
39    debug_assert_eq!(pool_spec.rank(), pool_spec.strides().len());
40    debug_assert_eq!(pool_spec.rank(), pool_spec.dilations().len());
41    debug_assert_eq!(pool_spec.rank(), pool_spec.kernel_shape.len());
42    debug_assert_eq!(pool_spec.rank(), input_geo.len());
43    debug_assert_eq!(pool_spec.rank(), output_geo.len());
44    let rank = pool_spec.rank();
45    let pad: TVec<usize> = match &pool_spec.padding {
46        PaddingSpec::Explicit(beg, end) => (0..rank).map(|r| beg[r] + end[r]).collect(),
47        PaddingSpec::Valid => tvec!(0; rank),
48        pad => todo!("Unsupported padding in deconvolution arguments {pad:?}"),
49    };
50    tract_itertools::izip!(
51        input_geo,
52        &pool_spec.kernel_shape,
53        output_geo,
54        pool_spec.strides().as_ref(),
55        pool_spec.dilations().as_ref(),
56        pad,
57    )
58    .map(|(x, k, y, s, d, p)| {
59        let adj = y.to_usize()? + p - s * (x.to_usize()? - 1) - (k.to_usize()? - 1) * d - 1;
60        Ok(adj)
61    })
62    .collect::<TractResult<TVec<usize>>>()
63}