tch_plus/nn/
conv_transpose.rs

//! Transposed convolution layers in one, two, and three dimensions.
2use super::Path;
3use crate::Tensor;
4use std::borrow::Borrow;
5
6/// A generic transposed convolution configuration.
7#[allow(clippy::upper_case_acronyms)]
8#[derive(Debug, Clone, Copy)]
9pub struct ConvTransposeConfigND<ND> {
10    pub stride: ND,
11    pub padding: ND,
12    pub output_padding: ND,
13    pub groups: i64,
14    pub bias: bool,
15    pub dilation: ND,
16    pub ws_init: super::Init,
17    pub bs_init: super::Init,
18}
19
20/// A transposed convolution configuration using the same values on each dimension.
21pub type ConvTransposeConfig = ConvTransposeConfigND<i64>;
22
23impl Default for ConvTransposeConfig {
24    fn default() -> Self {
25        ConvTransposeConfigND {
26            stride: 1,
27            padding: 0,
28            output_padding: 0,
29            dilation: 1,
30            groups: 1,
31            bias: true,
32            ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
33            bs_init: super::Init::Const(0.),
34        }
35    }
36}
37
38/// A generic transposed convolution layer.
39#[allow(clippy::upper_case_acronyms)]
40#[derive(Debug)]
41pub struct ConvTransposeND<ND> {
42    pub ws: Tensor,
43    pub bs: Option<Tensor>,
44    config: ConvTransposeConfigND<ND>,
45}
46
47/// A one dimension transposed convolution layer.
48pub type ConvTranspose1D = ConvTransposeND<[i64; 1]>;
49
50/// A two dimension transposed convolution layer.
51pub type ConvTranspose2D = ConvTransposeND<[i64; 2]>;
52
53/// A three dimension transposed convolution layer.
54pub type ConvTranspose3D = ConvTransposeND<[i64; 3]>;
55
56fn conv_transpose<'a, ND: std::convert::AsRef<[i64]>, T: Borrow<super::Path<'a>>>(
57    vs: T,
58    in_dim: i64,
59    out_dim: i64,
60    ksizes: ND,
61    config: ConvTransposeConfigND<ND>,
62) -> ConvTransposeND<ND> {
63    let vs = vs.borrow();
64    let bs = if config.bias { Some(vs.var("bias", &[out_dim], config.bs_init)) } else { None };
65    let mut weight_size = vec![in_dim, out_dim / config.groups];
66    weight_size.extend(ksizes.as_ref().iter());
67    let ws = vs.var("weight", weight_size.as_slice(), config.ws_init);
68    ConvTransposeND { ws, bs, config }
69}
70
71trait Create: std::convert::AsRef<[i64]> + std::marker::Sized {
72    fn make_array(i: i64) -> Self;
73
74    fn conv_transpose<'a, T: Borrow<super::Path<'a>>>(
75        vs: T,
76        in_dim: i64,
77        out_dim: i64,
78        ksize: i64,
79        config: ConvTransposeConfig,
80    ) -> ConvTransposeND<Self> {
81        let config = ConvTransposeConfigND::<Self> {
82            stride: Self::make_array(config.stride),
83            padding: Self::make_array(config.padding),
84            output_padding: Self::make_array(config.output_padding),
85            dilation: Self::make_array(config.dilation),
86            groups: config.groups,
87            bias: config.bias,
88            ws_init: config.ws_init,
89            bs_init: config.bs_init,
90        };
91        conv_transpose(vs, in_dim, out_dim, Self::make_array(ksize), config)
92    }
93}
94
95impl Create for [i64; 1] {
96    fn make_array(i: i64) -> Self {
97        [i]
98    }
99}
100
101impl Create for [i64; 2] {
102    fn make_array(i: i64) -> Self {
103        [i, i]
104    }
105}
106
107impl Create for [i64; 3] {
108    fn make_array(i: i64) -> Self {
109        [i, i, i]
110    }
111}
112
113/// Creates a one dimension transposed convolution layer.
114pub fn conv_transpose1d<'a, T: Borrow<Path<'a>>>(
115    vs: T,
116    i: i64,
117    o: i64,
118    k: i64,
119    c: ConvTransposeConfig,
120) -> ConvTranspose1D {
121    <[i64; 1]>::conv_transpose(vs, i, o, k, c)
122}
123
124/// Creates a two dimension transposed convolution layer.
125pub fn conv_transpose2d<'a, T: Borrow<Path<'a>>>(
126    vs: T,
127    i: i64,
128    o: i64,
129    k: i64,
130    c: ConvTransposeConfig,
131) -> ConvTranspose2D {
132    <[i64; 2]>::conv_transpose(vs, i, o, k, c)
133}
134
135/// Creates a three dimension transposed convolution layer.
136pub fn conv_transpose3d<'a, T: Borrow<Path<'a>>>(
137    vs: T,
138    i: i64,
139    o: i64,
140    k: i64,
141    c: ConvTransposeConfig,
142) -> ConvTranspose3D {
143    <[i64; 3]>::conv_transpose(vs, i, o, k, c)
144}
145
146impl super::module::Module for ConvTranspose1D {
147    fn forward(&self, xs: &Tensor) -> Tensor {
148        Tensor::conv_transpose1d(
149            xs,
150            &self.ws,
151            self.bs.as_ref(),
152            self.config.stride,
153            self.config.padding,
154            self.config.output_padding,
155            self.config.groups,
156            self.config.dilation,
157        )
158    }
159}
160
161impl super::module::Module for ConvTranspose2D {
162    fn forward(&self, xs: &Tensor) -> Tensor {
163        Tensor::conv_transpose2d(
164            xs,
165            &self.ws,
166            self.bs.as_ref(),
167            self.config.stride,
168            self.config.padding,
169            self.config.output_padding,
170            self.config.groups,
171            self.config.dilation,
172        )
173    }
174}
175
176impl super::module::Module for ConvTranspose3D {
177    fn forward(&self, xs: &Tensor) -> Tensor {
178        Tensor::conv_transpose3d(
179            xs,
180            &self.ws,
181            self.bs.as_ref(),
182            self.config.stride,
183            self.config.padding,
184            self.config.output_padding,
185            self.config.groups,
186            self.config.dilation,
187        )
188    }
189}