tract_core/ops/cnn/conv/
conv.rs

use tract_data::itertools::izip;
use tract_num_traits::Zero;

use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::PaddingSpec::*;
use crate::ops::einsum::EinSum;
use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
use crate::ops::matmul::pack::OptMatMulPack;
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::matmul::ModePicker;
use crate::ops::nn::Reduce;

use super::depth_wise::DepthWise;
use super::im2col::Im2Col;
use crate::ops::cnn::conv::KernelFormat;
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::MatMatMul;

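// The generic convolution operator. It covers grouped, depthwise and
// quantized variants; codegen lowers it to im2col + matmul, lazy im2col or a
// dedicated depthwise implementation depending on the input facts.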
#[derive(Debug, Clone, new, Hash)]
pub struct Conv {
    pub pool_spec: PoolSpec,
    pub kernel_fmt: KernelFormat,
    pub group: usize,
    // None -> float convolution
    // Some(I32) -> quantized kernels with i32 output; the last two quantization
    //   inputs (y0, y_scale) are ignored
    // Some(QXX) -> quantized output of type XX (I8, U8 or I32); the datum
    //   type's own quantization parameters are ignored in favor of the last
    //   two quantization inputs
    pub q_params: Option<DatumType>,
}

impl Conv {
    pub fn input_channels(&self) -> usize {
        self.pool_spec.input_channels
    }

    pub fn output_channels(&self) -> usize {
        self.pool_spec.output_channels
    }

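    // Rewire the kernel from its declared KernelFormat into the internal
    // Group, Output-channel, In*H*W (G,O,IHW) layout used by the matmul paths.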
    pub fn wire_kernel_as_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut kernel: OutletId,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(kernel)?;
        for (ix, op) in self
            .kernel_fmt
            .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
            .into_iter()
            .enumerate()
        {
            kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
        }
        Ok(tvec!(kernel))
    }

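    // Pre-pack the G,O,IHW kernel for the selected matmul kernel: axis 2 (IHW)
    // is the contraction (k) axis, axis 1 (O) the m axis.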
    fn wire_pack_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        format: PackedFormat,
        kernel: OutletId,
    ) -> TractResult<OutletId> {
        Ok(model.wire_node(
            format!("{name}.prep_kernel.pack"),
            OptMatMulPack {
                packers: vec![format],
                k_axis: 2,
                mn_axis: 1,
                mode_picker: ModePicker::Single,
            },
            &[kernel],
        )?[0])
    }

    // Returns the fused spec plus the bias outlet: a scalar bias is fused as a
    // scalar add, otherwise the bias is split into (group, channel-per-group).
    fn wire_bias_as_non_linear(
        &self,
        model: &mut TypedModel,
        name: &str,
        bias: OutletId,
        c_group_axis: usize,
    ) -> TractResult<(ProtoFusedSpec, OutletId)> {
        use tract_linalg::BinOp::Add;
        let fact = model.outlet_fact(bias)?;
        if fact.shape.volume().is_one() {
            Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
        } else {
            let bias = AxisOp::wire_split_axis(
                model,
                format!("{name}.reformat_bias"),
                bias,
                0,
                self.group,
            )?[0];
            let pfs =
                ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
            Ok((pfs, bias))
        }
    }

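    // Wire the quantized convolution: normalize kernel and input to an i8
    // flavour, run im2col and an i32 matmul, compensate both zero-points with
    // row/column sums, then requantize with the combined scales.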
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wires: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        ensure!(self.q_params.is_some());
        use crate::ops::matmul::quant as qmm;

        let c_dt = self.q_params.unwrap();
        let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
        else {
            bail!("Wrong number of inputs")
        };
        wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
        wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;

        let b_fact = model.outlet_fact(x)?.clone();

        let (_, _, k, n, mmm) = self.compute_geo(&b_fact)?;
        let packing = 1; // FIXME
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            // requant is performed before geo_reshape, so we need at most one geo axis to the
            // right
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[x, x0],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        // align sum_ker from G,C to the output "C" shape: N,HW,G,C (or N,G,C,HW)
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        // sum_x is N,G,HW; make it N,HW,G,C (or N,G,C,HW)
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }

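    // Fold the G axis of the matmul output back into the channel axis
    // (or simply drop it when group == 1).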
    pub fn wire_remove_group<D: DimLike>(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        mmm_output_shape: &[D],
        c_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let m = &mmm_output_shape[c_axis];
        let op = if self.group == 1 {
            AxisOp::Rm(c_axis - 1)
        } else {
            AxisOp::Reshape(
                c_axis - 1,
                tvec!(self.group.to_dim(), m.to_dim()),
                tvec!(m.to_dim() * self.group),
            )
        };
        model.wire_node(format!("{name}.reshape_group"), op, wire)
    }

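    // Wire the float convolution as an explicit Im2Col followed by an
    // optimized matmul over the packed kernel.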
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[x, _kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let b_dt = x_fact.datum_type;
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);

        let (_, _, k, _, mmm) = self.compute_geo(&x_fact)?;
        let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;

        let padding = model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(b_dt)?)?;

        let mut wire: TVec<_> = wire.into();
        wire[0] = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &x_fact.shape, mmm.clone())?,
            &[wire[0], padding],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;

        let wire = self
            .wire_mm_weights_bias(
                model,
                name,
                wire[0],
                g_o_ihw[0],
                bias,
                mmm,
                0,
                c_dt,
                mmm_output_shape.clone().into(),
                k.to_usize().unwrap(),
                c_axis,
                h_axis,
            )
            .context("in wire_mm_weights_bias")?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
    }

    // The mmm output shape always has N and G axes, with G right before C:
    // c_axis points to C, c_axis - 1 to G. E.g. an NHWC output N,H,W,C maps to
    // N,HW,G,C/G and an NCHW output to N,G,C/G,HW.
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
        mmm_output_shape.insert(c_axis, self.group.into());
        if h_axis > c_axis {
            h_axis += 1;
        }
        c_axis += 1;
        Ok((mmm_output_shape, c_axis, h_axis))
    }

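    // Drop the leading N axis introduced internally when the data format does
    // not carry a batch axis of its own.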
    fn wire_rm_n_if_needed(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if self.pool_spec.data_format.has_n() {
            Ok(wire.into())
        } else {
            model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
        }
    }

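    // Expand the collapsed HW axis of the matmul output back into the full
    // spatial dimensions of the convolution output.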
    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<TVec<OutletId>> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        model
            .wire_node(
                name,
                AxisOp::Reshape(
                    output_shape.h_axis(),
                    tvec!(geo_collapsed_out.to_dim()),
                    output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
                ),
                wire,
            )
            .context("in wire_geo_reshape")
    }

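    // Wire the convolution with a lazy im2col: padding, if any, is made
    // explicit with a Pad node, then patch extraction is deferred to matmul
    // time through precomputed n and k byte offsets into the input tensor.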
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let mut x_fact = model.outlet_fact(x)?.clone();
        let (geo, m, k, n, mmm) = self.compute_geo(&x_fact)?;
        let packing = 0;
        debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
        let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); x_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
            x_fact = model.outlet_fact(x)?.clone();
            let concrete_shape = x_fact.shape.as_concrete().unwrap();
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&x_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = x_fact.datum_type.size_of() as isize;
        let n_byte_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_byte_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
        let packer = mmm.packings()[packing]
            .1
            .downcast_ref::<PackedFormat>()
            .with_context(|| {
                format_err!(
414                    "Quand Im2Col expects regular packed format, got {:?}",
                    mmm.packings()[packing].1
                )
            })?
            .clone();
        let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
        let x = model.wire_node(
            format!("{name}.lazyIm2col"),
            LazyIm2Col { params: Arc::new(params) },
            &[x],
        )?[0];

        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            x,
            kernel,
            bias,
            mmm,
            packing,
            c_dt,
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
    }

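    // Compute the pool geometry and the matmul dimensions: m is output
    // channels per group, k is input channels per group times the kernel
    // surface, n is the number of output spatial points. Also picks the
    // matmul implementation.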
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
        let b_dt = input_fact.datum_type;
        let acc = if b_dt.is_float() { b_dt } else { i32::datum_type() };

        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;

        trace!("output channels: {:?}", self.output_channels());
        let m = self.output_channels() / self.group;
        let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
            / self.group;
        let n: TDim =
            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();

        let mmm = tract_linalg::ops()
            .mmm(acc, Some(m), Some(k), n.to_usize().ok())
            .with_context(|| format!("No multiplier for {acc:?}, {m}x{k}x{n}",))?;

        Ok((geo, m, k, n, mmm))
    }

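    // Wire the OptMatMul node over the packed kernel, mapping the output G
    // axis back to the kernel and input axes, and fusing the bias as a scalar
    // or per-row addition unless it is a constant zero.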
    #[allow(clippy::too_many_arguments)]
    fn wire_mm_weights_bias(
        &self,
        model: &mut TypedModel,
        name: &str,
        input: OutletId,
        g_o_ihw: OutletId,
        bias: OutletId,
        mmm: Box<dyn MatMatMul>,
        packing: usize,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        k: usize,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
        let a_pack = mmm.packings()[packing]
            .0
            .downcast_ref::<PackedFormat>()
            .context("Conv expects weights in regular packed format")?
            .clone();
        let packed_ker = self
            .wire_pack_g_o_ihw(model, name, a_pack, g_o_ihw)
            .context("in wire_pack_g_o_ihw")?;
        let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());

        c_to_a_axis_mapping.push((c_m_axis - 1, 0)); // Group
        c_to_b_axis_mapping.push((0, 0)); // Batch
        c_to_b_axis_mapping.push((c_m_axis - 1, 1)); // Group

        let geo = AddMatMulGeometry {
            k: k.to_dim(),
            c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
            c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
        };
        let mut ops: Vec<ProtoFusedSpec> =
            vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
        let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
        let bias_fact = model.outlet_fact(bias)?;
        if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
            let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
            wires.push(bias);
            ops.push(fused);
        }
        ops.push(ProtoFusedSpec::Store(vec![unsafe { mmm.c_view(c_m_axis, c_n_axis) }]));
        model.wire_node(
            format!("{name}.matmatmul"),
            OptMatMul::new(
                vec![mmm],
                ModePicker::Single,
                c_datum_type.fact(mmm_output_shape),
                c_m_axis,
                c_n_axis,
                ops,
                packing == 0 && self.group == 1,
            )?,
            &wires,
        )
    }

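    // Wire the convolution as the dedicated depthwise implementation (used
    // when group == input channels == output channels and the input shape is
    // concrete).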
    pub fn wire_as_depth_wise(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<OutletId> {
        let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let x_shape = x_fact.shape.as_concrete().unwrap();
        let ConcretePoolGeometry { input_shape, patch, output_shape } =
            self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
        bias = wire_reshape_bias_for_bin(
            model,
            name,
            bias,
            x_fact.rank(),
            c_axis,
            self.output_channels(),
        )?[0];
        let op = DepthWise::new(patch, input_shape, output_shape);
        Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
    }

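    // If an axis has stride > 1 but is trivial for the kernel (size 1, or
    // dilation divisible by the stride), extract the stride into an explicit
    // Downsample node wired ahead of the convolution.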
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let spatial_rank = self.pool_spec.rank();
        if let Some(axis) = (0..spatial_rank).find(|&ax| {
            self.pool_spec.stride(ax) > 1
                && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                && (self.pool_spec.kernel_shape[ax] == 1
                    || self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0)
        }) {
            let input_fact = model.outlet_fact(node.inputs[0])?;
            let downsample_factor = self.pool_spec.stride(axis);
            let mut new_op = self.clone();
            if new_op.pool_spec.dilation(axis) > 1 {
                new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
            }
            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
            let mut patch = TypedModelPatch::default();
            let mut taps = patch.taps(model, &node.inputs)?;
            let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
            taps[0] = patch.wire_node(
                format!("{}.downsample.{}", node.name, axis),
                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
                &[taps[0]],
            )?[0];
            let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

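    // A 1x1 convolution with unit strides and dilations, no padding and a
    // single group is a plain matrix product: rewrite it as an EinSum (plus a
    // broadcast bias addition in the float case).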
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

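    // Absorb a preceding constant zero Pad over the spatial axes into this
    // convolution's explicit padding.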
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if matches!(self.pool_spec.padding, ExplicitOnnxPool(_, _, _) | SameLower | SameUpper) {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        if !value.is_zero()?
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let Explicit(bef, aft) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = Explicit(before, after);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.taps(model, &node.inputs)?;
        wire[0] = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &wire)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

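    // Fold a per-channel Add/Sub/Mul/Div successor into the convolution by
    // adjusting the bias (and the kernel for Mul/Div), then shunt the
    // binary op out of the graph.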
    fn declutter_channel_arithmetic_succ(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() || self.group != 1 {
            return Ok(None);
        }
        let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) };
        let succ = model.node(succ_outlet.node);
        let Some(bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
        let other_input = succ.inputs[1 - succ_outlet.slot];
        let axes_mapping = model.node_axes_mapping(succ.id)?;
        let input_shape =
            self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        let conv_c_axis = input_shape.c_axis();
        if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
            [1 - succ_outlet.slot]
            .len()
            != 1
        {
            return Ok(None);
        };
        let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
        other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
        if *other_expected_shape != *model.outlet_fact(other_input)?.shape {
            return Ok(None);
        }

        let mut patch = TypedModelPatch::default();
        let [input, mut kernel, mut bias] = &*patch.taps(model, &node.inputs)? else {
            panic!("Expect three inputs");
        };
        let name = &node.name;
        let succ_name = &succ.name;

        let operand = patch.tap_model(model, other_input)?;

        let renamed_bias = format!("{name}.{succ_name}.bias");
        let renamed_kernel = format!("{name}.{succ_name}.kernel");
        bias = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape"),
            bias,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape_operand"),
            operand,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
        let kernel_fact = patch.outlet_fact(kernel)?;
        let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
        operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
            self.output_channels().to_dim();
        let operand_for_kernel = patch.wire_node(
            format!("{renamed_kernel}.reshape_operand"),
            AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
            &[operand],
        )?[0];

        if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
        } else if bin.0.is::<Sub>() {
            bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
        } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
        } else if bin.0.is::<Div>() {
            bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
        } else if bin.0.is::<Add>() {
            bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
        } else if bin.0.is::<Mul>() {
            bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
        } else {
            return Ok(None);
        };
        let wire = patch.wire_node(&node.name, self.clone(), &[*input, kernel, bias])?[0];
        patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
        Ok(Some(patch))
    }
}

impl Op for Conv {
    fn name(&self) -> Cow<str> {
        "Conv".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl EvalOp for Conv {
    fn is_stateless(&self) -> bool {
        true
    }

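    // Ad-hoc evaluation: build a one-off TypedModel around the inputs, wire
    // the convolution as a (quantized) im2col pair and run it.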
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();
        let wire: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
            .collect::<TractResult<_>>()?;
        let wire = unsafe {
            if self.q_params.is_some() {
                self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
            }
        };
        model.set_output_outlets(&wire)?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for Conv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
        let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
        if inputs.len() != 3 + q_inputs {
            bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
        }
        if self.q_params.is_some() {
            ensure!(inputs[2].datum_type == i32::datum_type());
            ensure!(inputs[3].datum_type == i32::datum_type());
            ensure!(inputs[4].datum_type.is_float());
            ensure!(inputs[5].datum_type == i32::datum_type());
            ensure!(inputs[6].datum_type.is_float());
            ensure!(inputs[7].datum_type == i32::datum_type());
            ensure!(inputs[8].datum_type.is_float());
        }
        ensure!(self.pool_spec.rank() + 2 == inputs[1].rank());
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
            anyhow::ensure!(bef.len() == self.pool_spec.rank());
            anyhow::ensure!(after.len() == self.pool_spec.rank());
        }
        ensure!(
            inputs[2].rank() == 0
                || (inputs[2].rank() == 1
                    && inputs[2].shape.volume() == self.output_channels().to_dim()),
            "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
            self.output_channels(),
            inputs[2]
        );
        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some(dt) = self.q_params {
            fact.datum_type = dt;
        } else {
            ensure!(
                inputs[0].datum_type == inputs[1].datum_type,
                "Convolution input, weights and bias must have the same type, got {inputs:?}",
            )
        }
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let padding = self.pool_spec.computed_padding(shape.hw_dims());
        for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim == 1
                && self.pool_spec.dilation(ix) == 1
                && self.pool_spec.stride(ix) == 1
                && padding[ix].pad_before.is_zero()
                && padding[ix].pad_after.is_zero()
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking(repr, (InOut::Out(0), ix + h_axis))?;
            }
        }
        if self.q_params.is_some() {
            for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
                if qp.rank() == 1 {
                    axes = match qp_ix {
                        3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
                        5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        _ => unreachable!(),
                    };
                }
            }
        }
        Ok(axes)
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_stride_slice_to_downsample);
        pass!(declutter_as_einsum);
        pass!(declutter_channel_arithmetic_succ);
        pass!(declutter_precursor_padding);
        Ok(None)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!((
            Cost::FMA(inputs[0].datum_type),
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if io == InOut::In(1) {
            return Ok(None);
        }
        if io == InOut::In(2) {
            if let &AxisOp::Rm(_) = change {
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(self.clone())),
                    wire_changes: tvec!(),
                }));
            }
        }
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        // remove n
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(op)),
                    wire_changes: tvec!(
                        (InOut::In(0), change.clone()),
                        (InOut::Out(0), change.clone())
                    ),
                }));
            }
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        // format swap: chw <-> hwc
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        // geo axis manips
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = self.kernel_fmt.h_axis();
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                let geo_axis = a - h_axis;
                (Rm(geo_axis), Rm(kh_axis + geo_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
        let new_op = Conv { pool_spec, ..self.clone() };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::In(1), kernel_adjusted),
                (InOut::Out(0), change.clone())
            ),
        }))
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        unsafe {
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
                    .context("in wire_as_quant_im2col")?;
                patch.shunt_outside(model, node.id.into(), wire[0])?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
                    .context("wire_as_lazy_im2col")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_depth_wise(&mut patch, &node.name, &inputs)
                    .context("wire_as_depth_wise")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
                    .context("in wire_as_im2col_pair")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}

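// Heuristic: lazy im2col is only worth it for a single ungrouped image with a
// kernel surface of more than 5 points.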
fn should_use_lazy(input_shape: &DataShape, pool_spec: &PoolSpec, group: usize) -> bool {
    input_shape.n().unwrap_or(&1) == &1
        && group == 1
        && pool_spec.kernel_shape.iter().product::<usize>() > 5
}

#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    #[test]
    fn onnx_basic_convinteger() {
        let op = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            rctensor0(0u32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let wire = model.wire_node(
            "conv",
            Conv {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: Explicit(tvec![0], tvec![0]),
                    input_channels: 1,
                    output_channels: 1,
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                group: 1,
                q_params: None,
            },
            &[wire[0], kernel, bias],
        )?;
        model.set_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 4); // source + conv + kernel + bias
        let cv = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0]));
        Ok(())
    }
}