tch_plus/nn/
conv.rs

1//! N-dimensional convolution layers.
2use super::Path;
3use crate::{TchError, Tensor};
4use std::borrow::Borrow;
5
/// How padding is performed by convolution operations
/// on the edge of the input tensor.
///
/// Each variant is translated by `PaddingMode::to_string` into the mode name
/// that is handed to the tensor padding operation.
#[derive(Debug, Clone, Copy)]
pub enum PaddingMode {
    /// Uses the `"constant"` padding mode, whose default fill value is zero.
    Zeros,
    /// Uses the `"reflect"` padding mode.
    Reflect,
    /// Uses the `"replicate"` padding mode.
    Replicate,
    /// Uses the `"circular"` padding mode.
    Circular,
}
15
16impl PaddingMode {
17    fn to_string(self) -> &'static str {
18        // This has to match the internal representation used on the C++
19        // side.
20        match self {
21            // The default value when using constant is zero.
22            PaddingMode::Zeros => "constant",
23            PaddingMode::Reflect => "reflect",
24            PaddingMode::Replicate => "replicate",
25            PaddingMode::Circular => "circular",
26        }
27    }
28
29    pub fn f_pad(
30        self,
31        xs: &Tensor,
32        reversed_padding_repeated_twice: &[i64],
33    ) -> Result<Tensor, TchError> {
34        xs.f_pad(reversed_padding_repeated_twice, self.to_string(), None)
35    }
36
37    pub fn pad(self, xs: &Tensor, reversed_padding_repeated_twice: &[i64]) -> Tensor {
38        xs.pad(reversed_padding_repeated_twice, self.to_string(), None)
39    }
40}
41
/// Generic convolution config.
///
/// `ND` is the type holding the per-dimension settings: a plain `i64` for
/// [`ConvConfig`], or an array such as `[i64; 2]` when a value is given for
/// each dimension.
// The `ND` suffix of the type name is deliberately upper-case, hence the
// lint exemption below.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct ConvConfigND<ND> {
    /// Stride forwarded to the convolution operation.
    pub stride: ND,
    /// Padding. With `PaddingMode::Zeros` it is forwarded to the convolution
    /// operation; with any other mode it is used to pad the input explicitly
    /// before the convolution.
    pub padding: ND,
    /// Dilation forwarded to the convolution operation.
    pub dilation: ND,
    /// Number of groups forwarded to the convolution operation; the weight
    /// tensor is sized with `in_dim / groups` input channels.
    pub groups: i64,
    /// Whether a `bias` variable is created for the layer.
    pub bias: bool,
    /// Initialization used for the `weight` variable.
    pub ws_init: super::Init,
    /// Initialization used for the `bias` variable.
    pub bs_init: super::Init,
    /// How the edges of the input are padded.
    pub padding_mode: PaddingMode,
}
55
/// Convolution config using the same parameters on all dimensions.
///
/// The scalar `stride`, `padding` and `dilation` are replicated once per
/// dimension when the config is passed to `conv1d`, `conv2d` or `conv3d`.
pub type ConvConfig = ConvConfigND<i64>;
58
59impl Default for ConvConfig {
60    fn default() -> Self {
61        ConvConfig {
62            stride: 1,
63            padding: 0,
64            dilation: 1,
65            groups: 1,
66            bias: true,
67            ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
68            bs_init: super::Init::Const(0.),
69            padding_mode: PaddingMode::Zeros,
70        }
71    }
72}
73
74impl Default for ConvConfigND<[i64; 2]> {
75    fn default() -> Self {
76        ConvConfigND::<[i64; 2]> {
77            stride: [1, 1],
78            padding: [0, 0],
79            dilation: [1, 1],
80            groups: 1,
81            bias: true,
82            ws_init: super::init::DEFAULT_KAIMING_UNIFORM,
83            bs_init: super::Init::Const(0.),
84            padding_mode: PaddingMode::Zeros,
85        }
86    }
87}
88
89/// The default convolution config without bias.
90pub fn no_bias() -> ConvConfig {
91    ConvConfig { bias: false, ..Default::default() }
92}
93
// TODO: const generics are available in stable Rust (since 1.51); `ND` could
// be expressed as `[i64; N]`.
/// A N-dimensional convolution layer.
///
/// `ND` is the array type holding the per-dimension settings, e.g.
/// `[i64; 2]` for a two-dimensional convolution.
#[derive(Debug)]
pub struct Conv<ND> {
    /// Weight tensor, created under the name `weight`.
    pub ws: Tensor,
    /// Bias tensor, created under the name `bias`; `None` when the config
    /// disables the bias.
    pub bs: Option<Tensor>,
    /// Padding amounts derived from `config.padding` when the layer is
    /// created; handed to `PaddingMode::pad` for the non-`Zeros` modes.
    reversed_padding_repeated_twice: Vec<i64>,
    /// The configuration this layer was created with.
    config: ConvConfigND<ND>,
}
103
/// One-dimensional convolution layer.
pub type Conv1D = Conv<[i64; 1]>;
106
/// Two-dimensional convolution layer.
pub type Conv2D = Conv<[i64; 2]>;
109
/// Three-dimensional convolution layer.
pub type Conv3D = Conv<[i64; 3]>;
112
113/// Creates a new convolution layer for any number of dimensions.
114pub fn conv<'a, ND: std::convert::AsRef<[i64]>, T: Borrow<super::Path<'a>>>(
115    vs: T,
116    in_dim: i64,
117    out_dim: i64,
118    ksizes: ND,
119    config: ConvConfigND<ND>,
120) -> Conv<ND> {
121    let vs = vs.borrow();
122    let bs = if config.bias { Some(vs.var("bias", &[out_dim], config.bs_init)) } else { None };
123    let mut weight_size = vec![out_dim, in_dim / config.groups];
124    weight_size.extend(ksizes.as_ref().iter());
125    let ws = vs.var("weight", weight_size.as_slice(), config.ws_init);
126    let mut reversed_padding_repeated_twice = vec![];
127    for &v in config.padding.as_ref().iter().rev() {
128        reversed_padding_repeated_twice.push(v)
129    }
130    for &v in config.padding.as_ref().iter().rev() {
131        reversed_padding_repeated_twice.push(v)
132    }
133    Conv { ws, bs, config, reversed_padding_repeated_twice }
134}
135
136trait Create: std::convert::AsRef<[i64]> + std::marker::Sized {
137    fn make_array(i: i64) -> Self;
138
139    fn conv<'a, T: Borrow<super::Path<'a>>>(
140        vs: T,
141        in_dim: i64,
142        out_dim: i64,
143        ksize: i64,
144        config: ConvConfig,
145    ) -> Conv<Self> {
146        let config = ConvConfigND::<Self> {
147            stride: Self::make_array(config.stride),
148            padding: Self::make_array(config.padding),
149            dilation: Self::make_array(config.dilation),
150            groups: config.groups,
151            bias: config.bias,
152            ws_init: config.ws_init,
153            bs_init: config.bs_init,
154            padding_mode: config.padding_mode,
155        };
156        conv(vs, in_dim, out_dim, Self::make_array(ksize), config)
157    }
158}
159
160impl Create for [i64; 1] {
161    fn make_array(i: i64) -> Self {
162        [i]
163    }
164}
165
166impl Create for [i64; 2] {
167    fn make_array(i: i64) -> Self {
168        [i, i]
169    }
170}
171
172impl Create for [i64; 3] {
173    fn make_array(i: i64) -> Self {
174        [i, i, i]
175    }
176}
177
178/// Creates a new one dimension convolution layer.
179pub fn conv1d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv1D {
180    <[i64; 1]>::conv(vs, i, o, k, c)
181}
182
183/// Creates a new two dimension convolution layer.
184pub fn conv2d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv2D {
185    <[i64; 2]>::conv(vs, i, o, k, c)
186}
187
188/// Creates a new three dimension convolution layer.
189pub fn conv3d<'a, T: Borrow<Path<'a>>>(vs: T, i: i64, o: i64, k: i64, c: ConvConfig) -> Conv3D {
190    <[i64; 3]>::conv(vs, i, o, k, c)
191}
192
193impl super::module::Module for Conv1D {
194    fn forward(&self, xs: &Tensor) -> Tensor {
195        let (xs, padding) = match self.config.padding_mode {
196            PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
197            p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0]),
198        };
199        xs.conv1d(
200            &self.ws,
201            self.bs.as_ref(),
202            self.config.stride,
203            padding,
204            self.config.dilation,
205            self.config.groups,
206        )
207    }
208}
209
210impl super::module::Module for Conv2D {
211    fn forward(&self, xs: &Tensor) -> Tensor {
212        let (xs, padding) = match self.config.padding_mode {
213            PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
214            p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0, 0]),
215        };
216        xs.conv2d(
217            &self.ws,
218            self.bs.as_ref(),
219            self.config.stride,
220            padding,
221            self.config.dilation,
222            self.config.groups,
223        )
224    }
225}
226
227impl super::module::Module for Conv3D {
228    fn forward(&self, xs: &Tensor) -> Tensor {
229        let (xs, padding) = match self.config.padding_mode {
230            PaddingMode::Zeros => (xs.shallow_clone(), self.config.padding),
231            p => (p.pad(xs, &self.reversed_padding_repeated_twice), [0, 0, 0]),
232        };
233        xs.conv3d(
234            &self.ws,
235            self.bs.as_ref(),
236            self.config.stride,
237            padding,
238            self.config.dilation,
239            self.config.groups,
240        )
241    }
242}