tract_hir/ops/cnn/
conv.rs

use crate::infer::*;
use crate::internal::*;
use crate::ops::cast::cast;

use tract_core::ops::cnn::conv::KernelFormat;
use tract_core::ops::cnn::{PaddingSpec, PoolSpec};
use tract_core::ops::nn::DataFormat;

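/// Inference-time (HIR) convolution. The geometry attributes (formats,
/// dilations, strides, padding, group) describe the operation itself, while
/// the `*_input` fields are optional input slot indices: they locate the
/// kernel, the bias and the quantization scales/zero-points among the node
/// inputs when the source operator provides them.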
#[derive(Debug, Clone, Default, Hash)]
pub struct Conv {
    pub data_format: DataFormat,
    pub kernel_fmt: KernelFormat,
    pub dilations: Option<TVec<usize>>,
    pub kernel_shape: Option<TVec<usize>>,
    pub padding: PaddingSpec,
    pub strides: Option<TVec<usize>>,
    pub group: Option<usize>,

    pub x_scale_input: Option<usize>,
    pub x_zero_point_input: Option<usize>,
    pub k_input: Option<usize>,
    pub k_scale_input: Option<usize>,
    pub k_zero_point_input: Option<usize>,

    pub y_scale_input: Option<usize>,
    pub y_zero_point_input: Option<usize>,

    pub bias_input: Option<usize>,

    pub override_output_datum_type: Option<DatumType>,
}

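// Chainable builder-style setters for configuring the op.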
impl Conv {
    pub fn hwc(self) -> Conv {
        Conv { data_format: DataFormat::HWC, ..self }
    }

    pub fn nhwc(self) -> Conv {
        Conv { data_format: DataFormat::NHWC, ..self }
    }

    pub fn hwio(self) -> Conv {
        Conv { kernel_fmt: KernelFormat::HWIO, ..self }
    }

    pub fn padding(self, padding: PaddingSpec) -> Conv {
        Conv { padding, ..self }
    }

    pub fn dilations(self, dilations: TVec<usize>) -> Conv {
        Conv { dilations: Some(dilations), ..self }
    }

    pub fn group(self, group: usize) -> Conv {
        Conv { group: Some(group), ..self }
    }

    pub fn strides(self, strides: TVec<usize>) -> Conv {
        Conv { strides: Some(strides), ..self }
    }

    pub fn kernel_shape(self, kernel_shape: TVec<usize>) -> Conv {
        Conv { kernel_shape: Some(kernel_shape), ..self }
    }

    pub fn bias_input(self, input: usize) -> Conv {
        Conv { bias_input: Some(input), ..self }
    }

    pub fn x_zero_point_input(self, input: usize) -> Conv {
        Conv { x_zero_point_input: Some(input), ..self }
    }

    pub fn k_zero_point_input(self, input: usize) -> Conv {
        Conv { k_zero_point_input: Some(input), ..self }
    }

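    /// Computes the full output shape from the full input and kernel shapes:
    /// the channel dim comes from the kernel O axis, the spatial dims from the
    /// padding/stride/dilation computation, and the batch dim (if any) is
    /// carried over from the input.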
    pub fn output_shape<D: DimLike>(&self, ishape: &[D], kshape: &[usize]) -> TractResult<TVec<D>> {
        debug_assert_eq!(
            ishape.len()
                + (self.data_format == DataFormat::HWC || self.data_format == DataFormat::CHW)
                    as usize,
            kshape.len(),
            "Input and kernel ranks are inconsistent"
        );
        let mut result: TVec<D> = ishape.into();
        let ishape = self.data_format.shape(ishape)?;
        let spatial_rank = ishape.hw_rank();
        let ones = tvec![1; spatial_rank];
        let kernel_spatial_shape = self.kernel_fmt.hw(kshape);
        let computed = self.padding.compute(
            ishape.hw_dims(),
            kernel_spatial_shape,
            self.dilations.as_ref().unwrap_or(&ones),
            self.strides.as_ref().unwrap_or(&ones),
        );
        let channels_out = *self.kernel_fmt.o(kshape);
        result[ishape.c_axis()] = channels_out.into();
        for (ix, d) in computed.iter().enumerate() {
            result[ishape.h_axis() + ix] = d.convoluted.clone();
        }
        Ok(result)
    }
}

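// As an Expansion, Conv only lives in the inference graph: `rules` propagates
// ranks, datum types and shapes, and `wire` lowers the op to the typed
// tract_core convolution.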
impl Expansion for Conv {
    fn name(&self) -> StaticName {
        "ConvHir".into()
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        if inputs.len() < 2 {
            bail!("Wrong number of inputs. Expected 2 or more, got {}", inputs.len());
        }
        let has_n = self.data_format == DataFormat::NHWC || self.data_format == DataFormat::NCHW;
        let k_input = &inputs[self.k_input.unwrap_or(1)];
        if let Some(kshape) = &self.kernel_shape {
            s.equals(&k_input.rank, kshape.len() as i64 + 2)?;
            for (ix, dim) in kshape.iter().enumerate() {
                s.equals(&k_input.shape[ix + self.kernel_fmt.h_axis()], TDim::from(*dim as i64))?;
            }
        }
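        // The kernel carries both I and O axes while the input has a single C
        // axis plus an optional N axis: rank(input) = rank(kernel) - 1 + has_n.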
        s.equals(&inputs[0].rank, k_input.rank.bex() + (has_n as usize as i64 - 1))?;
        s.equals(&outputs[0].rank, &inputs[0].rank)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &k_input.datum_type)?;
        if let Some(dt) = self.override_output_datum_type {
            s.equals(&outputs[0].datum_type, dt)?;
        } else {
            s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
        }
        if let Some(bias) = self.bias_input {
            // the bias datum type is ill-defined, so it is not checked here
            s.equals(&inputs[bias].rank, 1)?;
            s.given(&k_input.rank, move |s, krank| {
                let filter_o = match self.kernel_fmt {
                    KernelFormat::OIHW => &k_input.shape[0],
                    KernelFormat::HWIO => &k_input.shape[krank as usize - 1],
                    KernelFormat::OHWI => &k_input.shape[0],
                };
                s.equals(&inputs[bias].shape[0], filter_o)
            })?
        }
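        // Channel consistency: the input C dim must be group times the kernel I dim.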
        s.given_2(&inputs[0].rank, &k_input.rank, move |s, irank, krank| {
            let input_c =
                if self.data_format == DataFormat::NHWC || self.data_format == DataFormat::HWC {
                    &inputs[0].shape[irank as usize - 1]
                } else {
                    &inputs[0].shape[1]
                };
            let filter_i = match self.kernel_fmt {
                KernelFormat::OIHW => &k_input.shape[1],
                KernelFormat::HWIO => &k_input.shape[krank as usize - 2],
                KernelFormat::OHWI => &k_input.shape[krank as usize - 1],
            };
            s.equals(input_c.bex(), self.group.unwrap_or(1) as i64 * filter_i.bex())
        })?;
        s.given_2(&inputs[0].shape, &k_input.shape, move |s, ishape, kshape| {
            if let Some(kshape) =
                kshape.iter().map(|d| d.to_usize().ok()).collect::<Option<TVec<_>>>()
            {
                let oshape = self.output_shape(&ishape, &kshape)?;
                s.equals(&outputs[0].shape, oshape)?;
            }
            Ok(())
        })
    }

    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let kernel_input = self.k_input.unwrap_or(1);
        let kernel_fact = model.outlet_fact(inputs[kernel_input])?.clone();
        let input = model.outlet_fact(inputs[0])?.clone();
        let input_shape = self.data_format.shape(&input.shape)?;
        let kernel_full_shape =
            kernel_fact.shape.as_concrete().context("Expect concrete shape for kernel")?;
        let group = self.group.unwrap_or(1);
        let input_channels = self.kernel_fmt.input_channels(kernel_full_shape, group).into_owned();
        let output_channels =
            self.kernel_fmt.output_channels(kernel_full_shape, group).into_owned();
        if input_shape.c_dim() != &input_channels.to_dim() {
            bail!("Input has {} channels, kernel expects {}", input_shape.c_dim(), input_channels)
        }
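        // Float convolutions carry the bias in the input float type; integer
        // (quantized) ones accumulate it in i32.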
        let bias_dt =
            if input.datum_type.is_float() { input.datum_type } else { i32::datum_type() };
        let mut bias = if let Some(slot) = self.bias_input {
            model.wire_node(format!("{prefix}.bias"), cast(bias_dt), &[inputs[slot]])?[0]
        } else {
            model.add_const(format!("{prefix}.bias"), Tensor::zero_scalar_dt(bias_dt)?)?
        };
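        // Drop degenerate size-1 axes so a [1] bias is wired as a plain scalar.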
        while let Some(axis) = model
            .outlet_fact(bias)?
            .shape
            .to_tvec()
            .iter()
            .enumerate()
            .rev()
            .position(|(_, dim)| dim.is_one())
        {
            bias =
                model.wire_node(format!("{prefix}.bias_rm_{axis}"), AxisOp::Rm(axis), &[bias])?[0];
        }
        let mut wires = vec![inputs[0], inputs[kernel_input], bias];
        let pool_spec = PoolSpec {
            data_format: self.data_format,
            padding: self.padding.clone(),
            strides: self.strides.clone(),
            dilations: self.dilations.clone(),
            kernel_shape: self.kernel_fmt.hw(kernel_full_shape).into(),
            input_channels,
            output_channels,
        };

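        // The node is treated as quantized as soon as any scale or zero-point
        // input slot is declared.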
        let quantized = self.k_zero_point_input.is_some()
            || self.k_scale_input.is_some()
            || self.x_zero_point_input.is_some()
            || self.x_scale_input.is_some()
            || self.y_zero_point_input.is_some()
            || self.y_scale_input.is_some();
        let output_type = self.override_output_datum_type.unwrap_or(input.datum_type);
        if quantized {
            let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i32))?;
            let one = model.add_const(format!("{prefix}.one"), tensor0(1f32))?;

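            // qp! wires one quantization parameter: the declared input slot if
            // there is one, the provided default otherwise, cast to the
            // expected datum type.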
            macro_rules! qp {
                ($id: ident, $def: expr, $ty: ty) => {
                    let wire = self.$id.map(|i| inputs[i]).unwrap_or($def);
                    let wire = model.wire_node(
                        format!("{prefix}.cast_{}", stringify!($id)),
                        cast(<$ty>::datum_type()),
                        &[wire],
                    )?[0];
                    wires.push(wire);
                };
            }

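            // The six parameters are appended after [input, kernel, bias], in
            // the order the core op expects them.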
            qp!(x_zero_point_input, zero, i32);
            qp!(x_scale_input, one, f32);
            qp!(k_zero_point_input, zero, i32);
            qp!(k_scale_input, one, f32);
            qp!(y_zero_point_input, zero, i32);
            qp!(y_scale_input, one, f32);
        }

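        // Lower to the typed core convolution. The output type override is
        // only forwarded in the quantized case.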
        let reduced = tract_core::ops::cnn::Conv::new(
            pool_spec,
            self.kernel_fmt,
            group,
            Some(output_type).filter(|_| quantized),
        );
        model.wire_node(prefix, reduced, &wires)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::setup_test_logger;

    #[test]
    fn test_infer_with_known_kshape() {
        let mut op = expand(Conv::default().strides(tvec![2, 2]).kernel_shape(tvec![3, 3]));
        let ifact = f32::fact([1, 1, 7, 5]).into();
        let kfact = f32::fact([1, 1, 3, 3]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 3, 2]).into()));
    }

    #[test]
    fn test_infer_channels() {
        let mut op = expand(Conv::default()); // NCHW - OIHW
        let ifact = f32::fact([1, 2, 1, 1]).into();
        let kfact = f32::fact([3, 2, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 3, 1, 1]).into()));
    }
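
    // An illustrative sketch (not in the original suite): grouped convolution
    // in the default NCHW/OIHW layout. With group=2, the input C dim (4) must
    // equal group * kernel I dim (2), and the kernel O axis (6) gives the
    // output channel count.
    #[test]
    fn test_infer_group() {
        let mut op = expand(Conv::default().group(2));
        let ifact = f32::fact([1, 4, 1, 1]).into();
        let kfact = f32::fact([6, 2, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 6, 1, 1]).into()));
    }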

    #[test]
    fn test_infer_onnx_strides_no_padding() {
        let mut op = expand(Conv::default().strides(tvec![2, 2]));
        let ifact = f32::fact([1, 1, 7, 5]).into();
        let kfact = f32::fact([1, 1, 3, 3]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 3, 2]).into()));
    }

    #[test]
    fn test_infer_nhwc_1() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 2, 2, 2]).into();
        let kfact = f32::fact([2, 2, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 2, 2, 1]).into()));
    }

    #[test]
    fn test_eval_nhwc_1() -> TractResult<()> {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let res = op.eval(tvec!(
            Tensor::zero::<f32>(&[1, 2, 2, 2]).unwrap().into_tvalue(),
            Tensor::zero::<f32>(&[2, 2, 2, 1]).unwrap().into_tvalue(),
        ))?;
        Tensor::zero::<f32>(&[1, 2, 2, 1]).unwrap().close_enough(&res[0], false)
    }

    #[test]
    fn test_infer_nhwc_2() {
        setup_test_logger();
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 1, 2, 2]).into();
        let kfact = f32::fact([2, 1, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 2, 1]).into()));
    }

    #[test]
    fn test_eval_nhwc_2() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let i = tensor4(&[[[[0.0f32, 0.0], [1.0, 0.0]]]]);
        let k = tensor4(&[[[[0.0f32], [0.0]], [[1.0], [0.0]]]]);
        let e = tensor4(&[[[[1.0f32], [0.0]]]]);
        let res = op.eval(tvec!(i.into(), k.into())).unwrap();
        res[0].close_enough(&e, Approximation::Approximate).unwrap();
    }

    #[test]
    fn test_eval_nhwc_3() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let i = tensor4(&[[[[0.0f32, 1.0], [2.0, 3.0]], [[10.0, 11.0], [12.0, 13.0]]]]);
        let k = tensor4(&[[[[1.0f32, 0.0], [0.0, 1.0]]]]);
        let res = op.eval(tvec!(i.clone().into(), k.into())).unwrap();
        res[0].close_enough(&i, Approximation::Approximate).unwrap()
    }

    #[test]
    fn test_eval_nhwc_batch() {
        setup_test_logger();
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(
                tensor4(&[[[[2.0f32]]], [[[0.0f32]]]]).into(),
                tensor4(&[[[[1.0f32]]]]).into()
            ))
            .unwrap();
        result[0]
            .close_enough(&tensor4(&[[[[2.0f32]]], [[[0.0f32]]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_simple() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 2, 1]).into();
        let kfact = f32::fact([1, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 2, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_simple() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(tensor3(&[[[2.0f32], [0.0f32]]]).into(), tensor3(&[[[1.0f32]]]).into()))
            .unwrap();
        result[0]
            .close_enough(&tensor3(&[[[2.0f32], [0.0f32]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_batch() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([2, 1, 1]).into();
        let kfact = f32::fact([1, 1, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([2, 1, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_batch() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(tensor3(&[[[2.0f32]], [[0.0f32]]]).into(), tensor3(&[[[1.0f32]]]).into()))
            .unwrap();
        result[0]
            .close_enough(&tensor3(&[[[2.0f32]], [[0.0f32]]]), Approximation::Approximate)
            .unwrap();
    }

    #[test]
    fn test_infer_ntc_channel() {
        let mut op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let ifact = f32::fact([1, 1, 2]).into();
        let kfact = f32::fact([1, 2, 1]).into();
        let ofact = InferenceFact::default();
        let facts = op.infer_facts(tvec!(&ifact, &kfact), tvec!(&ofact), tvec!()).unwrap();
        assert_eq!(facts.1, tvec!(f32::fact([1, 1, 1]).into()));
    }

    #[test]
    fn test_eval_ntc_channel() {
        let op = expand(Conv::default().nhwc().hwio().padding(PaddingSpec::SameUpper));
        let result = op
            .eval(tvec!(
                tensor3(&[[[2.0f32, 0.0f32]]]).into(),
                tensor3(&[[[1.0f32], [0.0f32]]]).into()
            ))
            .unwrap();
        result[0].close_enough(&tensor3(&[[[2.0f32]]]), Approximation::Approximate).unwrap();
    }
}