use crate::infer::*;
use crate::internal::*;
use crate::ops::cast::cast;

use tract_core::ops::cnn::conv::KernelFormat;
use tract_core::ops::cnn::{PaddingSpec, PoolSpec};
use tract_core::ops::nn::DataFormat;

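/// Inference-time ("HIR") convolution, later expanded into `tract_core`'s `Conv`
/// once shapes and datum types are known.
///
/// The `*_input` fields are optional slot indices into the op inputs: `k_input`
/// locates the kernel (defaults to slot 1), `bias_input` the bias, and the
/// scale/zero-point slots the quantization parameters of a QLinearConv-style
/// quantized convolution. `override_output_datum_type` forces the output type
/// when it differs from the input type.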
#[derive(Debug, Clone, Default, Hash)]
pub struct Conv {
    pub data_format: DataFormat,
    pub kernel_fmt: KernelFormat,
    pub dilations: Option<TVec<usize>>,
    pub kernel_shape: Option<TVec<usize>>,
    pub padding: PaddingSpec,
    pub strides: Option<TVec<usize>>,
    pub group: Option<usize>,

    pub x_scale_input: Option<usize>,
    pub x_zero_point_input: Option<usize>,
    pub k_input: Option<usize>,
    pub k_scale_input: Option<usize>,
    pub k_zero_point_input: Option<usize>,

    pub y_scale_input: Option<usize>,
    pub y_zero_point_input: Option<usize>,

    pub bias_input: Option<usize>,

    pub override_output_datum_type: Option<DatumType>,
}

impl Conv {
    pub fn hwc(self) -> Conv {
        Conv { data_format: DataFormat::HWC, ..self }
    }

    pub fn nhwc(self) -> Conv {
        Conv { data_format: DataFormat::NHWC, ..self }
    }

    pub fn hwio(self) -> Conv {
        Conv { kernel_fmt: KernelFormat::HWIO, ..self }
    }

    pub fn padding(self, padding: PaddingSpec) -> Conv {
        Conv { padding, ..self }
    }

    pub fn dilations(self, dilations: TVec<usize>) -> Conv {
        Conv { dilations: Some(dilations), ..self }
    }

    pub fn group(self, group: usize) -> Conv {
        Conv { group: Some(group), ..self }
    }

    pub fn strides(self, strides: TVec<usize>) -> Conv {
        Conv { strides: Some(strides), ..self }
    }

    pub fn kernel_shape(self, kernel_shape: TVec<usize>) -> Conv {
        Conv { kernel_shape: Some(kernel_shape), ..self }
    }

    pub fn bias_input(self, input: usize) -> Conv {
        Conv { bias_input: Some(input), ..self }
    }

    pub fn x_zero_point_input(self, input: usize) -> Conv {
        Conv { x_zero_point_input: Some(input), ..self }
    }

    pub fn k_zero_point_input(self, input: usize) -> Conv {
        Conv { k_zero_point_input: Some(input), ..self }
    }

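    /// Computes the output shape for a given input and full kernel shape, honouring
    /// padding, strides and dilations (missing ones default to 1). The channel axis
    /// is replaced by the kernel's output-channel count and each spatial axis by the
    /// convolved dimension; e.g. a NCHW 1x1x7x5 input with a 3x3 OIHW kernel,
    /// stride 2 and valid padding yields 1x1x3x2, as exercised by
    /// `test_infer_with_known_kshape` below.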
    pub fn output_shape<D: DimLike>(&self, ishape: &[D], kshape: &[usize]) -> TractResult<TVec<D>> {
        debug_assert_eq!(
            ishape.len()
                + (self.data_format == DataFormat::HWC || self.data_format == DataFormat::CHW)
                    as usize,
            kshape.len(),
            "Input and kernel ranks are inconsistent"
        );
        let mut result: TVec<D> = ishape.into();
        let ishape = self.data_format.shape(ishape)?;
        let spatial_rank = ishape.hw_rank();
        let ones = tvec![1; spatial_rank];
        let kernel_spatial_shape = self.kernel_fmt.hw(kshape);
        let computed = self.padding.compute(
            ishape.hw_dims(),
            kernel_spatial_shape,
            self.dilations.as_ref().unwrap_or(&ones),
            self.strides.as_ref().unwrap_or(&ones),
        );
        let channels_out = *self.kernel_fmt.o(kshape);
        result[ishape.c_axis()] = channels_out.into();
        for (ix, d) in computed.iter().enumerate() {
            result[ishape.h_axis() + ix] = d.convoluted.clone();
        }
        Ok(result)
    }
}

impl Expansion for Conv {
    fn name(&self) -> StaticName {
        "ConvHir".into()
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

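    // Shape/type inference: ties the kernel rank (and, when known, its spatial
    // dimensions) to the input rank, propagates datum types unless overridden,
    // forces a rank-1 bias matching the kernel's output channels, checks that the
    // input channel count equals group * kernel input channels, and finally derives
    // the output shape through `output_shape` once concrete dimensions are available.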
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        if inputs.len() < 2 {
            bail!("Wrong number of inputs. Expected 2 or more, got {}", inputs.len());
        }
        let has_n = self.data_format == DataFormat::NHWC || self.data_format == DataFormat::NCHW;
        let k_input = &inputs[self.k_input.unwrap_or(1)];
        if let Some(kshape) = &self.kernel_shape {
            s.equals(&k_input.rank, kshape.len() as i64 + 2)?;
            for (ix, dim) in kshape.iter().enumerate() {
                s.equals(&k_input.shape[ix + self.kernel_fmt.h_axis()], TDim::from(*dim as i64))?;
            }
        }
        s.equals(&inputs[0].rank, k_input.rank.bex() + (has_n as usize as i64 - 1))?;
        s.equals(&outputs[0].rank, &inputs[0].rank)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &k_input.datum_type)?;
        if let Some(dt) = self.override_output_datum_type {
            s.equals(&outputs[0].datum_type, dt)?;
        } else {
            s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
        }
        if let Some(bias) = self.bias_input {
            s.equals(&inputs[bias].rank, 1)?;
            s.given(&k_input.rank, move |s, krank| {
                let filter_o = match self.kernel_fmt {
                    KernelFormat::OIHW => &k_input.shape[0],
                    KernelFormat::HWIO => &k_input.shape[krank as usize - 1],
                    KernelFormat::OHWI => &k_input.shape[0],
                };
                s.equals(&inputs[bias].shape[0], filter_o)
            })?
        }
        s.given_2(&inputs[0].rank, &k_input.rank, move |s, irank, krank| {
            let input_c =
                if self.data_format == DataFormat::NHWC || self.data_format == DataFormat::HWC {
                    &inputs[0].shape[irank as usize - 1]
                } else {
                    &inputs[0].shape[1]
                };
            let filter_i = match self.kernel_fmt {
                KernelFormat::OIHW => &k_input.shape[1],
                KernelFormat::HWIO => &k_input.shape[krank as usize - 2],
                KernelFormat::OHWI => &k_input.shape[krank as usize - 1],
            };
            s.equals(input_c.bex(), self.group.unwrap_or(1) as i64 * filter_i.bex())
        })?;
        s.given_2(&inputs[0].shape, &k_input.shape, move |s, ishape, kshape| {
            if let Some(kshape) =
                kshape.iter().map(|d| d.to_usize().ok()).collect::<Option<TVec<_>>>()
            {
                let oshape = self.output_shape(&ishape, &kshape)?;
                s.equals(&outputs[0].shape, oshape)?;
            }
            Ok(())
        })
    }

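    // Lowering to `tract_core::ops::cnn::Conv`: the bias is cast (i32 for integer
    // inputs, the input type otherwise) or defaulted to a zero scalar, its trivial
    // axes are stripped, and, for quantized variants, the six zero-point/scale
    // inputs are appended as extra wires with neutral defaults for missing slots.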
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let kernel_input = self.k_input.unwrap_or(1);
        let kernel_fact = model.outlet_fact(inputs[kernel_input])?.clone();
        let input = model.outlet_fact(inputs[0])?.clone();
        let input_shape = self.data_format.shape(&input.shape)?;
        let kernel_full_shape =
            kernel_fact.shape.as_concrete().context("Expect concrete shape for kernel")?;
        let group = self.group.unwrap_or(1);
        let input_channels = self.kernel_fmt.input_channels(kernel_full_shape, group).into_owned();
        let output_channels =
            self.kernel_fmt.output_channels(kernel_full_shape, group).into_owned();
        if input_shape.c_dim() != &input_channels.to_dim() {
            bail!("Input has {} channels, kernel expects {}", input_shape.c_dim(), input_channels)
        }
        let bias_dt =
            if input.datum_type.is_float() { input.datum_type } else { i32::datum_type() };
        let mut bias = if let Some(slot) = self.bias_input {
            model.wire_node(format!("{prefix}.bias"), cast(bias_dt), &[inputs[slot]])?[0]
        } else {
            model.add_const(format!("{prefix}.bias"), Tensor::zero_scalar_dt(bias_dt)?)?
        };
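        // Trim size-1 axes from the bias fact; with the rank-1 bias enforced by
        // `rules`, this collapses a [1] bias to a scalar and leaves a per-channel
        // [C] bias untouched.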
        while let Some(axis) = model
            .outlet_fact(bias)?
            .shape
            .to_tvec()
            .iter()
            .enumerate()
            .rev()
            .position(|(_, dim)| dim.is_one())
        {
            bias =
                model.wire_node(format!("{prefix}.bias_rm_{axis}"), AxisOp::Rm(axis), &[bias])?[0];
        }
        let mut wires = vec![inputs[0], inputs[kernel_input], bias];
        let pool_spec = PoolSpec {
            data_format: self.data_format,
            padding: self.padding.clone(),
            strides: self.strides.clone(),
            dilations: self.dilations.clone(),
            kernel_shape: self.kernel_fmt.hw(kernel_full_shape).into(),
            input_channels,
            output_channels,
        };

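        // Any quantization-related slot makes this a quantized convolution: the six
        // zero-point/scale wires are then appended in the fixed order x0, x_scale,
        // k0, k_scale, y0, y_scale, with defaults of 0 and 1 for missing slots,
        // zero points cast to i32 and scales cast to f32.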
        let quantized = self.k_zero_point_input.is_some()
            || self.k_scale_input.is_some()
            || self.x_zero_point_input.is_some()
            || self.x_scale_input.is_some()
            || self.y_zero_point_input.is_some()
            || self.y_scale_input.is_some();
        let output_type = self.override_output_datum_type.unwrap_or(input.datum_type);
        if quantized {
            let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i32))?;
            let one = model.add_const(format!("{prefix}.one"), tensor0(1f32))?;

            macro_rules! qp {
                ($id: ident, $def: expr, $ty: ty) => {
                    let wire = self.$id.map(|i| inputs[i]).unwrap_or($def);
                    let wire = model.wire_node(
                        format!("{prefix}.cast_{}", stringify!($id)),
                        cast(<$ty>::datum_type()),
                        &[wire],
                    )?[0];
                    wires.push(wire);
                };
            }

            qp!(x_zero_point_input, zero, i32);
            qp!(x_scale_input, one, f32);
            qp!(k_zero_point_input, zero, i32);
            qp!(k_scale_input, one, f32);
            qp!(y_zero_point_input, zero, i32);
            qp!(y_scale_input, one, f32);
        };

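        // Build the core convolution; the output datum type override is only passed
        // through for the quantized flavour.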
        let reduced = tract_core::ops::cnn::Conv::new(
            pool_spec,
            self.kernel_fmt,
            group,
            Some(output_type).filter(|_| quantized),
        );
        model.wire_node(prefix, reduced, &wires)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::setup_test_logger;

    #[test]
    fn test_infer_with_known_kshape() {
        let mut op = expand(Conv::default().strides(tvec![2, 2]).kernel_shape(tvec![3, 3]));
        let ifact = f32::fact([1, 1, 7, 5]).into();
        let kfact = f32::fact([1, 1, 3, 3]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 3, 2]).into()));
    }

    #[test]
    fn test_infer_channels() {
        let mut op = expand(Conv::default());
        let ifact = f32::fact([1, 2, 1, 1]).into();
        let kfact = f32::fact([3, 2, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 3, 1, 1]).into()));
    }

    #[test]
    fn test_infer_onnx_strides_no_padding() {
        let mut op = expand(Conv::default().strides(tvec![2, 2]));
        let ifact = f32::fact([1, 1, 7, 5]).into();
        let kfact = f32::fact([1, 1, 3, 3]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 3, 2]).into()));
    }

    #[test]
    fn test_infer_nhwc_1() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 2, 2, 2]).into();
        let kfact = f32::fact([2, 2, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 2, 2, 1]).into()));
    }

    #[test]
    fn test_eval_nhwc_1() -> TractResult<()> {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let res = op.eval(tvec!(
            Tensor::zero::<f32>(&[1, 2, 2, 2]).unwrap().into_tvalue(),
            Tensor::zero::<f32>(&[2, 2, 2, 1]).unwrap().into_tvalue(),
        ))?;
        Tensor::zero::<f32>(&[1, 2, 2, 1]).unwrap().close_enough(&res[0], false)
    }

    #[test]
    fn test_infer_nhwc_2() {
        setup_test_logger();
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 1, 2, 2]).into();
        let kfact = f32::fact([2, 1, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 2, 1]).into()));
    }

    #[test]
    fn test_eval_nhwc_2() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let i = tensor4(&[[[[0.0f32, 0.0], [1.0, 0.0]]]]);
        let k = tensor4(&[[[[0.0f32], [0.0]], [[1.0], [0.0]]]]);
        let e = tensor4(&[[[[1.0f32], [0.0]]]]);
        let res = op.eval(tvec!(i.into(), k.into())).unwrap();
        res[0].close_enough(&e, Approximation::Approximate).unwrap();
    }

    #[test]
    fn test_eval_nhwc_3() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let i = tensor4(&[[[[0.0f32, 1.0], [2.0, 3.0]], [[10.0, 11.0], [12.0, 13.0]]]]);
        let k = tensor4(&[[[[1.0f32, 0.0], [0.0, 1.0]]]]);
        let res = op.eval(tvec!(i.clone().into(), k.into())).unwrap();
        res[0].close_enough(&i, Approximation::Approximate).unwrap()
    }

    #[test]
    fn test_eval_nhwc_batch() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(
                tensor4(&[[[[2.0f32]]], [[[0.0f32]]]]).into(),
                tensor4(&[[[[1.0f32]]]]).into()
            ))
            .unwrap();
        result[0]
            .close_enough(&tensor4(&[[[[2.0f32]]], [[[0.0f32]]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_simple() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 2, 1]).into();
        let kfact = f32::fact([1, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 2, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_simple() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(tensor3(&[[[2.0f32], [0.0f32]]]).into(), tensor3(&[[[1.0f32]]]).into()))
            .unwrap();
        result[0]
            .close_enough(&tensor3(&[[[2.0f32], [0.0f32]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_batch() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([2, 1, 1]).into();
        let kfact = f32::fact([1, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([2, 1, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_batch() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(tensor3(&[[[2.0f32]], [[0.0f32]]]).into(), tensor3(&[[[1.0f32]]]).into()))
            .unwrap();
        result[0]
            .close_enough(&tensor3(&[[[2.0f32]], [[0.0f32]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_channel() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 1, 2]).into();
        let kfact = f32::fact([1, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_channel() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(
                tensor3(&[[[2.0f32, 0.0f32]]]).into(),
                tensor3(&[[[1.0f32], [0.0f32]]]).into()
            ))
            .unwrap();
        result[0].close_enough(&tensor3(&[[[2.0f32]]]), Approximation::Approximate).unwrap();
    }
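
    // A sketch of the dilated case, assuming the usual dilated-kernel arithmetic:
    // with the default NCHW/OIHW formats and valid padding, a dilation of 2 gives
    // the 3x3 kernel an effective 5x5 span, so a 1x1x7x5 input should infer to
    // 1x1x3x1 (same shape arithmetic as test_infer_with_known_kshape above).
    #[test]
    fn test_infer_with_dilations() {
        let mut op = expand(Conv::default().dilations(tvec![2, 2]));
        let ifact = f32::fact([1, 1, 7, 5]).into();
        let kfact = f32::fact([1, 1, 3, 3]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 3, 1]).into()));
    }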
}