1use super::Path;
3use crate::{TchError, Tensor};
4use std::borrow::Borrow;
5
/// Strategy used to fill the border of the input when a convolution layer pads it.
///
/// `Zeros` lets the convolution op apply the configured padding itself; every other variant
/// pads the input explicitly with `Tensor::pad` (using the mode name shown on each variant)
/// before running the convolution with zero padding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Zero padding (`"constant"` mode of `Tensor::pad`).
    Zeros,
    /// `"reflect"` mode of `Tensor::pad`.
    Reflect,
    /// `"replicate"` mode of `Tensor::pad`.
    Replicate,
    /// `"circular"` mode of `Tensor::pad`.
    Circular,
}
15
16impl PaddingMode {
17 fn to_string(self) -> &'static str {
18 match self {
21 PaddingMode::Zeros => "constant",
23 PaddingMode::Reflect => "reflect",
24 PaddingMode::Replicate => "replicate",
25 PaddingMode::Circular => "circular",
26 }
27 }
28
29 pub fn f_pad(
30 self,
31 xs: &Tensor,
32 reversed_padding_repeated_twice: &[i64],
33 ) -> Result<Tensor, TchError> {
34 xs.f_pad(reversed_padding_repeated_twice, self.to_string(), None)
35 }
36
37 pub fn pad(self, xs: &Tensor, reversed_padding_repeated_twice: &[i64]) -> Tensor {
38 xs.pad(reversed_padding_repeated_twice, self.to_string(), None)
39 }
40}
41
/// Configuration of a convolution layer, generic over the per-dimension parameter type `ND`.
///
/// `ND` is `i64` for `ConvConfig` (one value shared by every spatial dimension) or an
/// `[i64; N]` array holding one value per spatial dimension.
// NOTE(review): presumably allowed because of the all-caps `ND` acronym — confirm which
// identifier actually triggers the lint.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct ConvConfigND<ND> {
    /// Stride, passed through to the underlying `conv1d`/`conv2d`/`conv3d` call.
    pub stride: ND,
    /// Padding amount for each spatial dimension; how it is applied depends on `padding_mode`.
    pub padding: ND,
    /// Dilation, passed through to the underlying `conv1d`/`conv2d`/`conv3d` call.
    pub dilation: ND,
    /// Number of groups; the weight's second dimension is `in_dim / groups`.
    pub groups: i64,
    /// Whether a `bias` variable is created for the layer.
    pub bias: bool,
    /// Initializer for the `weight` variable.
    pub ws_init: super::Init,
    /// Initializer for the `bias` variable (only used when `bias` is set).
    pub bs_init: super::Init,
    /// `Zeros` hands `padding` to the convolution op; other modes pad the input with
    /// `Tensor::pad` before convolving.
    pub padding_mode: PaddingMode,
}

/// Convolution configuration with a single scalar per parameter; `conv1d`, `conv2d` and
/// `conv3d` replicate each scalar across all spatial dimensions.
pub type ConvConfig = ConvConfigND<i64>;
58
59impl Default for ConvConfig {
60 fn default() -> Self {
61 ConvConfig {
62 stride: 1,
63 padding: 0,
64 dilation: 1,
65 groups: 1,
66 bias: true,
67 ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
68 bs_init: super::Init::Const(0.),
69 padding_mode: PaddingMode::Zeros,
70 }
71 }
72}
73
74impl Default for ConvConfigND<[i64; 2]> {
75 fn default() -> Self {
76 ConvConfigND::<[i64; 2]> {
77 stride: [1, 1],
78 padding: [0, 0],
79 dilation: [1, 1],
80 groups: 1,
81 bias: true,
82 ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
83 bs_init: super::Init::Const(0.),
84 padding_mode: PaddingMode::Zeros,
85 }
86 }
87}
88
89pub fn no_bias() -> ConvConfig {
91 ConvConfig { bias: false, ..Default::default() }
92}
93
/// A convolution layer: its weight, optional bias and the configuration it was built with.
#[derive(Debug)]
pub struct Conv<ND> {
    /// Weight variable of shape `[out_dim, in_dim / groups, kernel sizes...]`.
    pub ws: Tensor,
    /// Bias variable of shape `[out_dim]`, present only when the configuration enables a bias.
    pub bs: Option<Tensor>,
    /// Padding sizes in the layout handed to `Tensor::pad` by `forward` when the padding mode
    /// is not `PaddingMode::Zeros`.
    reversed_padding_repeated_twice: Vec<i64>,
    /// Configuration the layer was created with.
    config: ConvConfigND<ND>,
}

/// One-dimensional convolution layer.
pub type Conv1D = Conv<[i64; 1]>;

/// Two-dimensional convolution layer.
pub type Conv2D = Conv<[i64; 2]>;

/// Three-dimensional convolution layer.
pub type Conv3D = Conv<[i64; 3]>;
112
113pub fn conv<'a, ND: std::convert::AsRef<[i64]>, T: Borrow<super::Path<'a>>>(
115 vs: T,
116 in_dim: i64,
117 out_dim: i64,
118 ksizes: ND,
119 config: ConvConfigND<ND>,
120) -> Conv<ND> {
121 let vs = vs.borrow();
122 let bs = if config.bias { Some(vs.var("bias", &[out_dim], config.bs_init)) } else { None };
123 let mut weight_size = vec![out_dim, in_dim / config.groups];
124 weight_size.extend(ksizes.as_ref().iter());
125 let ws = vs.var("weight", weight_size.as_slice(), config.ws_init);
126 let mut reversed_padding_repeated_twice = vec![];
127 for &v in config.padding.as_ref().iter().rev() {
128 reversed_padding_repeated_twice.push(v)
129 }
130 for &v in config.padding.as_ref().iter().rev() {
131 reversed_padding_repeated_twice.push(v)
132 }
133 Conv { ws, bs, config, reversed_padding_repeated_twice }
134}
135
136trait Create: std::convert::AsRef<[i64]> + std::marker::Sized {
137 fn make_array(i: i64) -> Self;
138
139 fn conv<'a, T: Borrow<super::Path<'a>>>(
140 vs: T,
141 in_dim: i64,
142 out_dim: i64,
143 ksize: i64,
144 config: ConvConfig,
145 ) -> Conv<Self> {
146 let config = ConvConfigND::<Self> {
147 stride: Self::make_array(config.stride),
148 padding: Self::make_array(config.padding),
149 dilation: Self::make_array(config.dilation),
150 groups: config.groups,
151 bias: config.bias,
152 ws_init: config.ws_init,
153 bs_init: config.bs_init,
154 padding_mode: config.padding_mode,
155 };
156 conv(vs, in_dim, out_dim, Self::make_array(ksize), config)
157 }
158}
159
160impl Create for [i64; 1] {
161 fn make_array(i: i64) -> Self {
162 [i]
163 }
164}
165
166impl Create for [i64; 2] {
167 fn make_array(i: i64) -> Self {
168 [i, i]
169 }
170}
171
172impl Create for [i64; 3] {
173 fn make_array(i: i64) -> Self {
174 [i, i, i]
175 }
176}
177
178pub fn conv1d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv1D {
180 <[i64; 1]>::conv(vs, i, o, k, c)
181}
182
183pub fn conv2d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv2D {
185 <[i64; 2]>::conv(vs, i, o, k, c)
186}
187
188pub fn conv3d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv3D {
190 <[i64; 3]>::conv(vs, i, o, k, c)
191}
192
193impl super::module::Module for Conv1D {
194 fn forward(&self, xs: &Tensor) -> Tensor {
195 let (xs, padding) = match self.config.padding_mode {
196 PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
197 p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0]),
198 };
199 xs.conv1d(
200 &self.ws,
201 self.bs.as_ref(),
202 self.config.stride,
203 padding,
204 self.config.dilation,
205 self.config.groups,
206 )
207 }
208}
209
210impl super::module::Module for Conv2D {
211 fn forward(&self, xs: &Tensor) -> Tensor {
212 let (xs, padding) = match self.config.padding_mode {
213 PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
214 p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0, 0]),
215 };
216 xs.conv2d(
217 &self.ws,
218 self.bs.as_ref(),
219 self.config.stride,
220 padding,
221 self.config.dilation,
222 self.config.groups,
223 )
224 }
225}
226
227impl super::module::Module for Conv3D {
228 fn forward(&self, xs: &Tensor) -> Tensor {
229 let (xs, padding) = match self.config.padding_mode {
230 PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
231 p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0, 0, 0]),
232 };
233 xs.conv3d(
234 &self.ws,
235 self.bs.as_ref(),
236 self.config.stride,
237 padding,
238 self.config.dilation,
239 self.config.groups,
240 )
241 }
242}