tract_tensorflow/ops/nn/
conv2d.rs

1use tract_hir::internal::*;
2use tract_hir::ops::cnn;
3use tract_hir::ops::nn::DataFormat;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn conv2d(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9    let strides = super::strides(pb)?;
10    let mut op =
11        cnn::Conv::default().hwio().padding(super::padding(pb)?).strides(strides[1..3].into());
12    if super::data_format(pb)? == DataFormat::NHWC {
13        op = op.nhwc()
14    }
15    Ok(expand(op))
16}
17
#[cfg(test)]
mod tests {
    // The camelCase test names follow the TensorFlow test cases these checks
    // appear to be ported from; they are kept recognizable rather than snake_case.
    #![allow(non_snake_case)]
    use super::*;
    use tract_hir::ops::cnn::{Conv, PaddingSpec};
    use tract_ndarray::*;

    /// Builds an f32 tensor of shape `sizes` filled with 1.0, 2.0, 3.0, ... in row-major order.
    fn mk(sizes: &[usize]) -> Tensor {
        Array::range(1f32, sizes.iter().product::<usize>() as f32 + 1.0, 1.0)
            .into_shape_with_order(sizes)
            .unwrap()
            .into()
    }

    /// Builds an NHWC-input, HWIO-kernel convolution with the given padding.
    ///
    /// `width_stride` is the horizontal stride and `height_stride` the vertical one;
    /// tract takes spatial strides in `[height, width]` order.
    fn make_conv(
        width_stride: usize,
        height_stride: usize,
        padding: PaddingSpec,
    ) -> Box<dyn InferenceOp> {
        let conv = Conv::default()
            .nhwc()
            .hwio()
            .padding(padding)
            .strides(tvec![height_stride, width_stride]);
        expand(conv)
    }

    /// Runs an NHWC/HWIO convolution with the same stride on both spatial axes
    /// and checks the output element-for-element against `expect`.
    fn verify(input: Tensor, filter: Tensor, stride: usize, padding: PaddingSpec, expect: &[f32]) {
        let result = make_conv(stride, stride, padding)
            .eval(tvec![input.into(), filter.into()])
            .unwrap()
            .remove(0);
        assert_eq!(expect.len(), result.shape().iter().product::<usize>());
        let found = result.to_array_view::<f32>().unwrap();
        let expect = ArrayD::from_shape_vec(found.shape(), expect.to_vec()).unwrap();
        assert_eq!(expect, found);
    }

    #[test]
    fn testConv2D3CNoopFilter() {
        verify(
            mk(&[1, 2, 3, 3]),
            tensor4(&[[[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]]),
            1,
            PaddingSpec::Valid,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0, 17.0, 18.0,
            ],
        )
    }

    #[test]
    fn testConv2D1x1Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[1, 1, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[
                30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0,
                174.0, 216.0, 258.0, 210.0, 261.0, 312.0,
            ],
        );
    }

    #[test]
    fn testConv2D1x2Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[1, 2, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0],
        )
    }

    #[test]
    fn testConv2D2x1Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 1, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[465.0, 504.0, 543.0, 618.0, 675.0, 732.0, 771.0, 846.0, 921.0],
        );
    }

    #[test]
    fn testConv2D2x2Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0],
        )
    }

    #[test]
    fn testConv2D2x2FilterStride2() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            2,
            PaddingSpec::Valid,
            &[2271.0, 2367.0, 2463.0],
        )
    }

    #[test]
    fn testConv2D2x2FilterStride2Same() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            2,
            PaddingSpec::SameUpper,
            &[2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0],
        );
    }

    #[test]
    fn test_conv_1() {
        let conv = make_conv(1, 1, PaddingSpec::SameUpper);
        // NHWC
        let data = tensor4(&[[[[1f32]]]]);
        // HWIO
        let filter = tensor4(&[[[[0.0f32]]], [[[1.0]]], [[[0.0]]]]);
        let exp = tensor4(&[[[[1f32]]]]);

        let result = conv.eval(tvec![data.into(), filter.into()]).unwrap();
        result[0].close_enough(&exp, Approximation::Approximate).unwrap()
    }

    #[test]
    fn test_conv_2() {
        let conv = make_conv(1, 1, PaddingSpec::SameUpper);
        // NHWC
        let data = tensor4(&[[[[142.3088f32], [48.891083]], [[208.3187], [-11.274994]]]]);
        // HWIO
        let filter =
            tensor4(&[[[[160.72833f32]], [[107.84076]]], [[[247.50552]], [[-38.738464]]]]);
        let exp = tensor4(&[[[[80142.31f32], [5067.5586]], [[32266.81], [-1812.2109]]]]);
        let got = &conv.eval(tvec![data.into(), filter.into()]).unwrap()[0];
        exp.close_enough(got, Approximation::Approximate).unwrap()
    }

    #[test]
    fn inference_1() {
        // Width stride 1, height stride 3.
        let mut op = make_conv(1, 3, PaddingSpec::Valid);
        let img = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 7, 1]).unwrap());
        let ker = InferenceFact::from(Tensor::zero::<f32>(&[1, 3, 1, 1]).unwrap());
        let any = InferenceFact::default();

        let (_, output_facts, _) = op.infer_facts(tvec![&img, &ker], tvec![&any], tvec!()).unwrap();

        // Valid padding along the width: 7 - 3 + 1 output columns.
        assert_eq!(output_facts, tvec![f32::fact([1, 1, 7 - 3 + 1, 1]).into()]);
    }

    #[test]
    fn inference_2() {
        let mut op = make_conv(1, 1, PaddingSpec::SameUpper);
        let img = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 1, 1]).unwrap());
        let ker = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 1, 1]).unwrap());
        let any = InferenceFact::default();

        let (_, output_facts, _) = op.infer_facts(tvec![&img, &ker], tvec![&any], tvec!()).unwrap();

        assert_eq!(output_facts, tvec![f32::fact([1, 1, 1, 1]).into()]);
    }
}