// tensor_rs/tensor_trait/convolution.rs

use crate::tensor::PaddingMode;

3pub trait Convolution where Self: std::marker::Sized {
4
5    fn conv2d(&self, filter: &Self,
6                  stride: (usize, usize),
7                  padding: (usize, usize),
8                  dilation: (usize, usize),
9                  padding_mode: PaddingMode
10    ) -> Self;
11
12    fn conv2d_grad(&self, filter: &Self,
13                       stride: (usize, usize),
14                       padding: (usize, usize),
15                       dilation: (usize, usize),
16                       padding_mode: PaddingMode,
17                       output_grad: &Self
18    ) -> (Self, Self);
19
20    fn conv_gen(&self, filter: &Self,
21                    stride: &[usize],
22                    padding: &[usize],
23                    dilation: &[usize],
24                    padding_mode: PaddingMode
25    ) -> Self;
26
27    fn conv_grad_gen(&self, filter: &Self,
28                         stride: &[usize],
29                         padding: &[usize],
30                         dilation: &[usize],
31                         padding_mode: PaddingMode,
32                         output_grad: &Self,
33    ) -> (Self, Self);
34}