1use super::Path;
3use crate::Tensor;
4use std::borrow::Borrow;
5
6#[allow(clippy::upper_case_acronyms)]
8#[derive(Debug, Clone, Copy)]
9pub struct ConvTransposeConfigND<ND> {
10 pub stride: ND,
11 pub padding: ND,
12 pub output_padding: ND,
13 pub groups: i64,
14 pub bias: bool,
15 pub dilation: ND,
16 pub ws_init: super::Init,
17 pub bs_init: super::Init,
18}
19
20pub type ConvTransposeConfig = ConvTransposeConfigND<i64>;
22
23impl Default for ConvTransposeConfig {
24 fn default() -> Self {
25 ConvTransposeConfigND {
26 stride: 1,
27 padding: 0,
28 output_padding: 0,
29 dilation: 1,
30 groups: 1,
31 bias: true,
32 ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
33 bs_init: super::Init::Const(0.),
34 }
35 }
36}
37
38#[allow(clippy::upper_case_acronyms)]
40#[derive(Debug)]
41pub struct ConvTransposeND<ND> {
42 pub ws: Tensor,
43 pub bs: Option<Tensor>,
44 config: ConvTransposeConfigND<ND>,
45}
46
47pub type ConvTranspose1D = ConvTransposeND<[i64; 1]>;
49
50pub type ConvTranspose2D = ConvTransposeND<[i64; 2]>;
52
53pub type ConvTranspose3D = ConvTransposeND<[i64; 3]>;
55
56fn conv_transpose<'a, ND: std::convert::AsRef<[i64]>, T: Borrow<super::Path<'a>>>(
57 vs: T,
58 in_dim: i64,
59 out_dim: i64,
60 ksizes: ND,
61 config: ConvTransposeConfigND<ND>,
62) -> ConvTransposeND<ND> {
63 let vs = vs.borrow();
64 let bs = if config.bias { Some(vs.var("bias", &[out_dim], config.bs_init)) } else { None };
65 let mut weight_size = vec![in_dim, out_dim / config.groups];
66 weight_size.extend(ksizes.as_ref().iter());
67 let ws = vs.var("weight", weight_size.as_slice(), config.ws_init);
68 ConvTransposeND { ws, bs, config }
69}
70
71trait Create: std::convert::AsRef<[i64]> + std::marker::Sized {
72 fn make_array(i: i64) -> Self;
73
74 fn conv_transpose<'a, T: Borrow<super::Path<'a>>>(
75 vs: T,
76 in_dim: i64,
77 out_dim: i64,
78 ksize: i64,
79 config: ConvTransposeConfig,
80 ) -> ConvTransposeND<Self> {
81 let config = ConvTransposeConfigND::<Self> {
82 stride: Self::make_array(config.stride),
83 padding: Self::make_array(config.padding),
84 output_padding: Self::make_array(config.output_padding),
85 dilation: Self::make_array(config.dilation),
86 groups: config.groups,
87 bias: config.bias,
88 ws_init: config.ws_init,
89 bs_init: config.bs_init,
90 };
91 conv_transpose(vs, in_dim, out_dim, Self::make_array(ksize), config)
92 }
93}
94
95impl Create for [i64; 1] {
96 fn make_array(i: i64) -> Self {
97 [i]
98 }
99}
100
101impl Create for [i64; 2] {
102 fn make_array(i: i64) -> Self {
103 [i, i]
104 }
105}
106
107impl Create for [i64; 3] {
108 fn make_array(i: i64) -> Self {
109 [i, i, i]
110 }
111}
112
113pub fn conv_transpose1d<'a, T: Borrow<Path<'a>>>(
115 vs: T,
116 i: i64,
117 o: i64,
118 k: i64,
119 c: ConvTransposeConfig,
120) -> ConvTranspose1D {
121 <[i64; 1]>::conv_transpose(vs, i, o, k, c)
122}
123
124pub fn conv_transpose2d<'a, T: Borrow<Path<'a>>>(
126 vs: T,
127 i: i64,
128 o: i64,
129 k: i64,
130 c: ConvTransposeConfig,
131) -> ConvTranspose2D {
132 <[i64; 2]>::conv_transpose(vs, i, o, k, c)
133}
134
135pub fn conv_transpose3d<'a, T: Borrow<Path<'a>>>(
137 vs: T,
138 i: i64,
139 o: i64,
140 k: i64,
141 c: ConvTransposeConfig,
142) -> ConvTranspose3D {
143 <[i64; 3]>::conv_transpose(vs, i, o, k, c)
144}
145
146impl super::module::Module for ConvTranspose1D {
147 fn forward(&self, xs: &Tensor) -> Tensor {
148 Tensor::conv_transpose1d(
149 xs,
150 &self.ws,
151 self.bs.as_ref(),
152 self.config.stride,
153 self.config.padding,
154 self.config.output_padding,
155 self.config.groups,
156 self.config.dilation,
157 )
158 }
159}
160
161impl super::module::Module for ConvTranspose2D {
162 fn forward(&self, xs: &Tensor) -> Tensor {
163 Tensor::conv_transpose2d(
164 xs,
165 &self.ws,
166 self.bs.as_ref(),
167 self.config.stride,
168 self.config.padding,
169 self.config.output_padding,
170 self.config.groups,
171 self.config.dilation,
172 )
173 }
174}
175
176impl super::module::Module for ConvTranspose3D {
177 fn forward(&self, xs: &Tensor) -> Tensor {
178 Tensor::conv_transpose3d(
179 xs,
180 &self.ws,
181 self.bs.as_ref(),
182 self.config.stride,
183 self.config.padding,
184 self.config.output_padding,
185 self.config.groups,
186 self.config.dilation,
187 )
188 }
189}