1use tract_hir::internal::*;
2use tract_hir::ops::cnn;
3use tract_hir::ops::nn::DataFormat;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn conv2d(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9 let strides = super::strides(pb)?;
10 let mut op =
11 cnn::Conv::default().hwio().padding(super::padding(pb)?).strides(strides[1..3].into());
12 if super::data_format(pb)? == DataFormat::NHWC {
13 op = op.nhwc()
14 }
15 Ok(expand(op))
16}
17
// Unit tests for the Conv2D translation: numeric checks through `eval` and
// shape checks through `infer_facts`. Test names such as `testConv2D1x1Filter`
// appear to mirror TensorFlow's own conv op tests (presumably ported from there).
#[cfg(test)]
mod tests {
    // The ported test names keep their CamelCase spelling.
    #![allow(non_snake_case)]
    use super::*;
    use tract_hir::ops::cnn::{Conv, PaddingSpec};
    use tract_ndarray::*;

    /// Builds an f32 tensor of shape `sizes` filled with 1.0, 2.0, 3.0, ... in
    /// row-major order.
    fn mk(sizes: &[usize]) -> Tensor {
        Array::range(1f32, sizes.iter().product::<usize>() as f32 + 1.0, 1.0)
            .into_shape_with_order(sizes)
            .unwrap()
            .into()
    }

    /// Builds a convolution with NHWC input and an HWIO kernel.
    ///
    /// Strides are passed to tract as `[v_stride, h_stride]`, presumably
    /// (height, width) — note the argument order is the reverse of that list.
    fn make_conv(h_stride: usize, v_stride: usize, padding: PaddingSpec) -> Box<dyn InferenceOp> {
        expand(Conv::default().nhwc().hwio().padding(padding).strides(tvec![v_stride, h_stride]))
    }

    /// Runs a convolution that uses `stride` on both spatial axes.
    ///
    /// Asserts that the output has exactly `expect.len()` elements and that its
    /// values, read in row-major order, equal `expect` exactly (no tolerance).
    fn verify(input: Tensor, filter: Tensor, stride: usize, padding: PaddingSpec, expect: &[f32]) {
        let result = make_conv(stride, stride, padding)
            .eval(tvec![input.into(), filter.into()])
            .unwrap()
            .remove(0);
        assert_eq!(expect.len(), result.shape().iter().product::<usize>());
        let found = result.to_array_view::<f32>().unwrap();
        // Reshape the flat expectation to the actual output shape before comparing.
        let expect = ArrayD::from_shape_vec(found.shape(), expect.to_vec()).unwrap();
        assert_eq!(expect, found);
    }

    /// A 1x1 identity kernel over 3 channels leaves the 1x2x3x3 input unchanged.
    #[test]
    fn testConv2D3CNoopFilter() {
        verify(
            mk(&[1, 2, 3, 3]),
            tensor4(&[[[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]]),
            1,
            PaddingSpec::Valid,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0, 17.0, 18.0,
            ],
        )
    }

    /// 1x1 kernel mixing 3 input channels into 3 outputs; output stays 2x3 spatially.
    #[test]
    fn testConv2D1x1Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[1, 1, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[
                30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0,
                174.0, 216.0, 258.0, 210.0, 261.0, 312.0,
            ],
        );
    }

    /// 1x2 kernel (HWIO height 1, width 2) with valid padding: 2x3 input -> 2x2 output.
    #[test]
    fn testConv2D1x2Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[1, 2, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0],
        )
    }

    /// 2x1 kernel with valid padding: 2x3 input -> 1x3 output.
    #[test]
    fn testConv2D2x1Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 1, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[465.0, 504.0, 543.0, 618.0, 675.0, 732.0, 771.0, 846.0, 921.0],
        );
    }

    /// 2x2 kernel with valid padding: 2x3 input -> 1x2 output.
    #[test]
    fn testConv2D2x2Filter() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            1,
            PaddingSpec::Valid,
            &[2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0],
        )
    }

    /// Same 2x2 kernel with stride 2 and valid padding: only the first window fits.
    #[test]
    fn testConv2D2x2FilterStride2() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            2,
            PaddingSpec::Valid,
            &[2271.0, 2367.0, 2463.0],
        )
    }

    /// Stride 2 with SAME (upper) padding: a second, zero-padded window appears on
    /// the width axis, giving a 1x2 output.
    #[test]
    fn testConv2D2x2FilterStride2Same() {
        verify(
            mk(&[1, 2, 3, 3]),
            mk(&[2, 2, 3, 3]),
            2,
            PaddingSpec::SameUpper,
            &[2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0],
        );
    }

    /// A 3x1 kernel (HWIO shape [3, 1, 1, 1]) over a 1x1 input with SAME padding:
    /// only the center tap (1.0) lands on real data, so the output equals the input.
    #[test]
    fn test_conv_1() {
        let conv = make_conv(1, 1, PaddingSpec::SameUpper);
        let data = tensor4(&[[[[1f32]]]]);
        let filter = tensor4(&[[[[0.0f32]]], [[[1.0]]], [[[0.0]]]]);
        let exp = tensor4(&[[[[1f32]]]]);

        let result = conv.eval(tvec![data.into(), filter.into()]).unwrap();
        result[0].close_enough(&exp, Approximation::Approximate).unwrap()
    }

    /// 2x2 input, 2x2 kernel, SAME (upper) padding; values are compared approximately
    /// because the expected outputs are rounded f32 literals.
    #[test]
    fn test_conv_2() {
        let conv = make_conv(1, 1, PaddingSpec::SameUpper);
        let data = tensor4(&[[[[142.3088f32], [48.891083]], [[208.3187], [-11.274994]]]]);
        let filter =
            tensor4(&[[[[160.72833f32]], [[107.84076]]], [[[247.50552]], [[-38.738464]]]]);
        let exp = tensor4(&[[[[80142.31f32], [5067.5586]], [[32266.81], [-1812.2109]]]]);
        let got = &conv.eval(tvec![data.into(), filter.into()]).unwrap()[0];
        exp.close_enough(got, true).unwrap()
    }

    /// Shape inference with a 1x3 kernel, strides [3, 1] and valid padding over a
    /// 1x7 image: height stays 1, width becomes 7 - 3 + 1 = 5.
    #[test]
    fn inference_1() {
        let mut op = make_conv(1, 3, PaddingSpec::Valid);
        let img = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 7, 1]).unwrap());
        let ker = InferenceFact::from(Tensor::zero::<f32>(&[1, 3, 1, 1]).unwrap());
        let any = InferenceFact::default();

        let (_, output_facts, _) = op.infer_facts(tvec![&img, &ker], tvec![&any], tvec!()).unwrap();

        assert_eq!(output_facts, tvec![f32::fact([1, 1, (7 - 3 + 1), 1]).into()]);
    }

    /// Shape inference for the degenerate 1x1 image / 1x1 kernel case under SAME
    /// (upper) padding: the output shape equals the input shape.
    #[test]
    fn inference_2() {
        let mut op = make_conv(1, 1, PaddingSpec::SameUpper);
        let img = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 1, 1]).unwrap());
        let ker = InferenceFact::from(Tensor::zero::<f32>(&[1, 1, 1, 1]).unwrap());
        let any = InferenceFact::default();

        let (_, output_facts, _) = op.infer_facts(tvec![&img, &ker], tvec![&any], tvec!()).unwrap();

        assert_eq!(output_facts, tvec![f32::fact([1, 1, 1, 1]).into()]);
    }
}