tract_core/ops/cnn/conv/conv.rs

use tract_data::itertools::izip;
use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFormat};
use tract_linalg::WeightType;
use tract_num_traits::Zero;

use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::cnn::conv::block_quant::{BlockQuantIntoShape, SplitGroupBlockQuant};
use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::PaddingSpec::*;
use crate::ops::einsum::EinSum;
use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::matmul::ModePicker;
use crate::ops::nn::Reduce;

use super::depth_wise::DepthWise;
use super::im2col::Im2Col;
use crate::ops::cnn::conv::{block_quant_aware_weight_shape, KernelFormat};
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::mmm::{MMMInputFormat, MatMatMul};
use tract_linalg::pack::PackedFormat;

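/// The unified convolution operator. `pool_spec` carries the geometry (data
/// format, kernel shape, padding, strides, dilations and channel counts),
/// `kernel_fmt` describes the weight layout, `group` the number of feature
/// groups, and `q_params` the optional quantized output type.
///
/// A minimal float 2D convolution, with hypothetical channel counts, can be
/// built along the lines of the test module at the bottom of this file:
///
/// ```ignore
/// let conv = Conv {
///     pool_spec: PoolSpec {
///         data_format: DataFormat::NCHW,
///         kernel_shape: tvec!(3, 3),
///         padding: PaddingSpec::Valid,
///         dilations: None,
///         strides: None,
///         input_channels: 8,
///         output_channels: 16,
///     },
///     kernel_fmt: KernelFormat::OIHW,
///     group: 1,
///     q_params: None,
/// };
/// ```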
#[derive(Debug, Clone, new, Hash)]
pub struct Conv {
    pub pool_spec: PoolSpec,
    pub kernel_fmt: KernelFormat,
    pub group: usize,
    // None -> float convolution
    // Some(I32) -> quantized kernels, but the output stays i32; the last two Q inputs
    // are ignored
    // Some(QXX) -> quantized output of type XX; the parameters carried by the datum
    // type (I8, U8 or I32) are ignored in favor of the last two Q inputs
    pub q_params: Option<DatumType>,
}

impl Conv {
    pub fn input_channels(&self) -> usize {
        self.pool_spec.input_channels
    }

    pub fn output_channels(&self) -> usize {
        self.pool_spec.output_channels
    }

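    /// Wires the ops bringing the kernel to the internal Group,O,IHW layout.
    /// Opaque (block-quantized) kernels are split by group then reshaped as block
    /// quant payloads; plain tensors go through the KernelFormat axis shuffles.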
    pub fn wire_kernel_as_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut kernel: OutletId,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(kernel)?;
        if fact.datum_type.is_opaque() {
            ensure!(self.kernel_fmt == KernelFormat::OIHW && fact.rank() == 0);
            kernel = model.wire_node(
                format!("{name}.prep_kernel.g"),
                SplitGroupBlockQuant { group: self.group },
                &[kernel],
            )?[0];
            kernel = model.wire_node(
                format!("{name}.prep_kernel.ihw"),
                BlockQuantIntoShape {
                    shape: tvec!(
                        self.output_channels() / self.group,
                        self.input_channels() / self.group
                            * self.pool_spec.kernel_shape.iter().product::<usize>(),
                    ),
                },
                &[kernel],
            )?[0];
            Ok(tvec!(kernel))
        } else {
            for (ix, op) in self
                .kernel_fmt
                .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
                .into_iter()
                .enumerate()
            {
                kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
            }
            Ok(tvec!(kernel))
        }
    }

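    /// Packs the Group,O,IHW kernel in the input format expected by the selected
    /// matmul: block-quant packing for opaque weights, regular packing otherwise.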
    fn wire_pack_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        format: &dyn MMMInputFormat,
        kernel: OutletId,
    ) -> TractResult<OutletId> {
        let fact = model.outlet_fact(kernel)?;
        let wire = if fact.datum_type.is_opaque() {
            let fact = model
                .outlet_fact(kernel)?
                .opaque_fact
                .as_ref()
                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
                .context("Only BlockQuant weights are supported")?;
            model.wire_node(
                format!("{name}.prep_kernel.pack"),
                OptSimpleMatMulPack {
                    packed_format: format
                        .downcast_ref::<PackedBlockQuantFormat>()
                        .context("Expect a block quant format")?
                        .clone(),
                    k: fact.k(),
                    m: fact.m(),
                },
                &[kernel],
            )?
        } else {
            let format = format
                .downcast_ref::<PackedFormat>()
                .context("Expect regular packing for numeric weights")?;
            model.wire_node(
                format!("{name}.prep_kernel.pack"),
                OptMatMulPack {
                    packers: vec![format.clone()],
                    k_axis: 2,
                    mn_axis: 1,
                    mode_picker: ModePicker::Single,
                },
                &[kernel],
            )?
        };
        Ok(wire[0])
    }

    /// Wires the bias as a fused op on the matmul output: a scalar bias becomes a
    /// fused scalar add, a per-channel bias is split by group and added per row.
    fn wire_bias_as_non_linear(
        &self,
        model: &mut TypedModel,
        name: &str,
        bias: OutletId,
        c_group_axis: usize,
    ) -> TractResult<(ProtoFusedSpec, OutletId)> {
        use tract_linalg::BinOp::Add;
        let fact = model.outlet_fact(bias)?;
        if fact.shape.volume().is_one() {
            Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
        } else {
            let bias = AxisOp::wire_split_axis(
                model,
                format!("{name}.reformat_bias"),
                bias,
                0,
                self.group,
            )?[0];
            let pfs =
                ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
            Ok((pfs, bias))
        }
    }

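    /// Quantized flavour of the im2col + OptMatMul lowering: normalizes the inputs
    /// to i8, combines the scales, compensates the zero points and requantizes to
    /// the requested output type.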
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wires: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        ensure!(self.q_params.is_some());
        use crate::ops::matmul::quant as qmm;

        let c_dt = self.q_params.unwrap();
        let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
        else {
            bail!("Wrong number of inputs")
        };
        wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
        wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;

        let a_fact = model.outlet_fact(kernel)?.clone();
        let b_fact = model.outlet_fact(x)?.clone();

        let (_geo, m, k, n) = self.compute_geo(&b_fact)?;
        let (mmm, packing) = self.choose_impl(&b_fact, &a_fact, m, k, &n)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            // requant is performed before geo_reshape, so we need at most one geo axis to the
            // right
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &b_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[x, x0],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        // align sum_A from G,C to "C" shape: N,HW,G,C (or N,G,C,HW)
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        // sum_b is N,G,HW. make it N,HW,G,C or N,G,C,HW
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }

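    /// Collapses the separate G and C axes of the matmul output back into a single
    /// channel axis (or simply drops the G axis when group == 1).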
    pub fn wire_remove_group<D: DimLike>(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        mmm_output_shape: &[D],
        c_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let m = &mmm_output_shape[c_axis];
        let op = if self.group == 1 {
            AxisOp::Rm(c_axis - 1)
        } else {
            AxisOp::Reshape(
                c_axis - 1,
                tvec!(self.group.to_dim(), m.to_dim()),
                tvec!(m.to_dim() * self.group),
            )
        };
        model.wire_node(format!("{name}.reshape_group"), op, wire)
    }

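    /// Float flavour of the im2col + OptMatMul lowering: pads with a constant zero,
    /// materializes the patches with Im2Col, then wires the optimized matmul.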
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[x, w, bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let w_fact = model.outlet_fact(w)?.clone();
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);

        let (_, m, k, n) = self.compute_geo(&x_fact)?;
        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
        let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;

        let padding =
            model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(x_fact.datum_type)?)?;

        let mut wire: TVec<_> = wire.into();
        wire[0] = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &x_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[wire[0], padding],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;

        let wire = self
            .wire_mm_weights_bias(
                model,
                name,
                wire[0],
                g_o_ihw[0],
                bias,
                mmm,
                packing,
                c_dt,
                mmm_output_shape.clone().into(),
                k.to_usize().unwrap(),
                c_axis,
                h_axis,
            )
            .context("in wire_mm_weights_bias")?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
    }

    // The mmm output shape always has N and G axes. G sits right before C:
    // c_axis points to C, c_axis - 1 points to G.
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
        mmm_output_shape.insert(c_axis, self.group.into());
        if h_axis > c_axis {
            h_axis += 1;
        }
        c_axis += 1;
        Ok((mmm_output_shape, c_axis, h_axis))
    }

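    /// Drops the batch axis added by the lowering when the data format carries no N.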
    fn wire_rm_n_if_needed(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if self.pool_spec.data_format.has_n() {
            Ok(wire.into())
        } else {
            model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
        }
    }

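    /// Expands the collapsed HW axis of the matmul output back to the full spatial
    /// dimensions of the convolution output.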
    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<TVec<OutletId>> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        model
            .wire_node(
                name,
                AxisOp::Reshape(
                    output_shape.h_axis(),
                    tvec!(geo_collapsed_out.to_dim()),
                    output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
                ),
                wire,
            )
            .context("in wire_geo_reshape")
    }

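    /// Lazy variant of the im2col lowering: padding is materialized up front by a
    /// Pad op, then LazyIm2Col resolves the patch offsets at run time instead of
    /// copying the patches out.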
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let mut x_fact = model.outlet_fact(x)?.clone();
        let w_fact = model.outlet_fact(kernel)?.clone();
        let (geo, m, k, n) = self.compute_geo(&x_fact)?;
        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
        debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
        let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); x_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
            x_fact = model.outlet_fact(x)?.clone();
            let concrete_shape = x_fact.shape.as_concrete().unwrap();
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&x_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = x_fact.datum_type.size_of() as isize;
        let n_byte_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_byte_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
        let packer = mmm.packings()[packing]
            .1
            .downcast_ref::<PackedFormat>()
            .with_context(|| {
                format_err!(
                    "LazyIm2col expects a regular packed format, got {:?}",
                    mmm.packings()[packing].1
                )
            })?
            .clone();
        let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
        let x = model.wire_node(
            format!("{name}.lazyIm2col"),
            LazyIm2Col { params: Arc::new(params) },
            &[x],
        )?[0];

        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            x,
            kernel,
            bias,
            mmm,
            packing,
            c_dt,
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
    }

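    /// Computes the pool geometry and the m, k, n dimensions of the equivalent
    /// matmul: m is output channels per group, k is input channels per group times
    /// the kernel surface, n is the output spatial volume.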
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim)> {
        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;

        trace!("output channels: {:?}", self.output_channels());
        let m = self.output_channels() / self.group;
        let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
            / self.group;
        let n: TDim =
            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();
        Ok((geo, m, k, n))
    }

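    /// Picks a MatMatMul implementation and packing index. Block-quantized weights
    /// scan every kernel for a packing whose precursors match the weight format and
    /// activation type, preferring the cheapest and largest tile; numeric weights go
    /// through the generic m/k/n-driven selection.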
    fn choose_impl(
        &self,
        input_fact: &TypedFact,
        weight_fact: &TypedFact,
        m: usize,
        k: usize,
        n: &TDim,
    ) -> TractResult<(Box<dyn MatMatMul>, usize)> {
        let w_dt = weight_fact.datum_type;
        let x_dt = input_fact.datum_type;

        let acc = if x_dt.is_float() { x_dt } else { i32::datum_type() };
        if w_dt.is_opaque() {
            let bqf = weight_fact
                .opaque_fact
                .as_ref()
                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
                .unwrap();
            let weight_type = WeightType::BlockQuant(bqf.format.clone());
            tract_linalg::ops()
                .mmm_impls()
                .iter()
                .filter(|mmm| mmm.internal_type() == acc)
                .flat_map(|mmm| {
                    mmm.packings().iter().enumerate().map(move |(ix, p)| (mmm, ix, &p.0, &p.1))
                })
                .filter(|(_, _, pa, pb)| {
                    pb.precursor() == x_dt.into() && pa.precursor() == weight_type
                })
                .map(|(mmm, p, _, _)| (mmm.clone(), p))
                .min_by_key(|(mmm, _)| {
                    mmm.quality().cost() as isize * 1000 - (mmm.mr() * mmm.nr()) as isize
                })
                .context("No matmul found")
        } else {
            let mmm = tract_linalg::ops()
                .mmm(acc, Some(m), Some(k), n.to_usize().ok())
                .context("No matmul found")?;
            let packing = mmm
                .packings()
                .iter()
                .position(|p| {
                    p.0.precursor() == w_dt.unquantized().into()
                        && p.1.precursor() == x_dt.unquantized().into()
                })
                .context("No packing found")?;
            Ok((mmm, packing))
        }
    }

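    /// Wires the OptMatMul node: packs the kernel, maps the batch and group axes of
    /// the output back onto the inputs, fuses the bias addition when it is not all
    /// zero, and ends with the store spec.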
    #[allow(clippy::too_many_arguments)]
    fn wire_mm_weights_bias(
        &self,
        model: &mut TypedModel,
        name: &str,
        input: OutletId,
        g_o_ihw: OutletId,
        bias: OutletId,
        mmm: Box<dyn MatMatMul>,
        packing: usize,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        k: usize,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
        let a_pack = &mmm.packings()[packing].0;
        let packed_ker = self
            .wire_pack_g_o_ihw(model, name, &**a_pack, g_o_ihw)
            .context("in wire_pack_g_o_ihw")?;
        let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());

        c_to_a_axis_mapping.push((c_m_axis - 1, 0)); // Group
        c_to_b_axis_mapping.push((0, 0)); // Batch
        c_to_b_axis_mapping.push((c_m_axis - 1, 1)); // Group

        let geo = AddMatMulGeometry {
            k: k.to_dim(),
            c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
            c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
        };
        let mut ops: Vec<ProtoFusedSpec> =
            vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
        let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
        let bias_fact = model.outlet_fact(bias)?;
        if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
            let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
            wires.push(bias);
            ops.push(fused);
        }
        ops.push(ProtoFusedSpec::Store(vec![unsafe {
            mmm.c_view(Some(c_m_axis), Some(c_n_axis))
        }]));
        model.wire_node(
            format!("{name}.matmatmul"),
            OptMatMul::new(
                vec![mmm],
                ModePicker::Single,
                c_datum_type.fact(mmm_output_shape),
                Some(c_m_axis),
                Some(c_n_axis),
                ops,
                packing == 0 && self.group == 1,
            )?,
            &wires,
        )
    }

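    /// Lowers a fully grouped convolution (group == input channels == output
    /// channels) to the dedicated DepthWise op.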
    pub fn wire_as_depth_wise(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<OutletId> {
        let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let x_shape = x_fact.shape.as_concrete().unwrap();
        let ConcretePoolGeometry { input_shape, patch, output_shape } =
            self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
        bias = wire_reshape_bias_for_bin(
            model,
            name,
            bias,
            x_fact.rank(),
            c_axis,
            self.output_channels(),
        )?[0];
        let op = DepthWise::new(patch, input_shape, output_shape);
        Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
    }

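    /// Declutter pass: when a strided axis needs no context (kernel of 1, or a
    /// dilation divisible by the stride), the stride is extracted into an explicit
    /// Downsample op wired ahead of the convolution.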
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let spatial_rank = self.pool_spec.rank();
        if let Some(axis) = (0..spatial_rank).find(|&ax| {
            self.pool_spec.stride(ax) > 1
                && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                && (self.pool_spec.kernel_shape[ax] == 1
                    || self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0)
        }) {
            let input_fact = model.outlet_fact(node.inputs[0])?;
            let downsample_factor = self.pool_spec.stride(axis);
            let mut new_op = self.clone();
            if new_op.pool_spec.dilation(axis) > 1 {
                new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
            }
            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
            let mut patch = TypedModelPatch::default();
            let mut taps = patch.taps(model, &node.inputs)?;
            let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
            taps[0] = patch.wire_node(
                format!("{}.downsample.{}", node.name, axis),
                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
                &[taps[0]],
            )?[0];
            let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

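    /// Declutter pass: a 1x1, stride-1, dilation-1, unpadded, ungrouped convolution
    /// is a plain matrix multiplication, so it is rewritten as an EinSum (with an
    /// explicit bias addition in the float case).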
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

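    /// Declutter pass: absorbs a preceding constant zero Pad over the spatial axes
    /// into the convolution's explicit padding.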
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if matches!(self.pool_spec.padding, ExplicitOnnxPool(_, _, _) | SameLower | SameUpper) {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        if !value.is_zero()?
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let Explicit(bef, aft) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = Explicit(before, after);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.taps(model, &node.inputs)?;
        wire[0] = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &wire)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

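    /// Declutter pass: folds a following per-channel Add/Sub/Mul/Div into the
    /// convolution's bias (and kernel, for Mul and Div). Only applies to float,
    /// ungrouped convolutions.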
    fn declutter_channel_arithmetic_succ(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() || self.group != 1 {
            return Ok(None);
        }
        let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) };
        let succ = model.node(succ_outlet.node);
        let Some(bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
        let other_input = succ.inputs[1 - succ_outlet.slot];
        let axes_mapping = model.node_axes_mapping(succ.id)?;
        let input_shape =
            self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        let conv_c_axis = input_shape.c_axis();
        if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
            [1 - succ_outlet.slot]
            .len()
            != 1
        {
            return Ok(None);
        };
        let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
        other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
        if *other_expected_shape != *model.outlet_fact(other_input)?.shape {
            return Ok(None);
        }

        let mut patch = TypedModelPatch::default();
        let [input, mut kernel, mut bias] = *patch.taps(model, &node.inputs)? else {
            panic!("Expect three inputs");
        };
        let name = &node.name;
        let succ_name = &succ.name;

        let operand = patch.tap_model(model, other_input)?;

        let renamed_bias = format!("{name}.{succ_name}.bias");
        let renamed_kernel = format!("{name}.{succ_name}.kernel");
        bias = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape"),
            bias,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape_operand"),
            operand,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
        let kernel_fact = patch.outlet_fact(kernel)?;
        let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
        operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
            self.output_channels().to_dim();
        let operand_for_kernel = patch.wire_node(
            format!("{renamed_kernel}.reshape_operand"),
            AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
            &[operand],
        )?[0];

        if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
        } else if bin.0.is::<Sub>() {
            bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
        } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
        } else if bin.0.is::<Div>() {
            bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
        } else if bin.0.is::<Add>() {
            bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
        } else if bin.0.is::<Mul>() {
            bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
        } else {
            return Ok(None);
        };
        let wire = patch.wire_node(&node.name, self.clone(), &[input, kernel, bias])?[0];
        patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
        Ok(Some(patch))
    }
}

impl Op for Conv {
    fn name(&self) -> StaticName {
        "Conv".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl EvalOp for Conv {
    fn is_stateless(&self) -> bool {
        true
    }

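    // Evaluation wires the convolution into a throwaway TypedModel and runs it,
    // so eval and the optimized codegen share the same im2col lowering.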
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();
        let wire: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
            .collect::<TractResult<_>>()?;
        let wire = unsafe {
            if self.q_params.is_some() {
                self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
            }
        };
        model.set_output_outlets(&wire)?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for Conv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
        let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
        ensure!(inputs[1].datum_type.is_number() || self.kernel_fmt == KernelFormat::OIHW);
        if inputs.len() != 3 + q_inputs {
            bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
        }
        if self.q_params.is_some() {
            ensure!(inputs[2].datum_type == i32::datum_type());
            ensure!(inputs[3].datum_type == i32::datum_type());
            ensure!(inputs[4].datum_type.is_float());
            ensure!(inputs[5].datum_type == i32::datum_type());
            ensure!(inputs[6].datum_type.is_float());
            ensure!(inputs[7].datum_type == i32::datum_type());
            ensure!(inputs[8].datum_type.is_float());
        }
        let weight_shape = block_quant_aware_weight_shape(inputs[1])?;
        ensure!(self.pool_spec.rank() + 2 == weight_shape.len());
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
            anyhow::ensure!(bef.len() == self.pool_spec.rank());
            anyhow::ensure!(after.len() == self.pool_spec.rank());
        }
        ensure!(
            inputs[2].rank() == 0
                || (inputs[2].rank() == 1
                    && inputs[2].shape.volume() == self.output_channels().to_dim()),
            "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
            self.output_channels(),
            inputs[2]
        );
        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some(dt) = self.q_params {
            fact.datum_type = dt;
        } else {
            ensure!(
                inputs[1].datum_type.is_opaque() || inputs[0].datum_type == inputs[1].datum_type,
                "Convolution input, weights and bias must have the same type, got {inputs:?}",
            )
        }
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let padding = self.pool_spec.computed_padding(shape.hw_dims());
        for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim == 1
                && self.pool_spec.dilation(ix) == 1
                && self.pool_spec.stride(ix) == 1
                && padding[ix].pad_before.is_zero()
                && padding[ix].pad_after.is_zero()
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking(repr, (InOut::Out(0), ix + h_axis))?;
            }
        }
        if self.q_params.is_some() {
            for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
                if qp.rank() == 1 {
                    axes = match qp_ix {
                        3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
                        5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        _ => unreachable!(),
                    };
                }
            }
        }
        Ok(axes)
    }

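    /// Runs the declutter passes in order, returning the first patch that fires.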
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_stride_slice_to_downsample);
        pass!(declutter_as_einsum);
        pass!(declutter_channel_arithmetic_succ);
        pass!(declutter_precursor_padding);
        Ok(None)
    }

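    /// Cost model: one FMA per batch element, input channel, output channel, output
    /// point and kernel element, divided by the group count.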
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!((
            Cost::FMA(inputs[0].datum_type),
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )))
    }

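    /// Axis-change tracking: kernel changes are refused, a bias axis removal is
    /// absorbed as-is, removing N or moving C between first and last position swaps
    /// the data format, and remaining spatial-axis changes are propagated to both
    /// the pool spec and the kernel.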
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if io == InOut::In(1) {
            return Ok(None);
        }
        if io == InOut::In(2) {
            if let &AxisOp::Rm(_) = change {
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(self.clone())),
                    wire_changes: tvec!(),
                }));
            }
        }
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        // remove n
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(op)),
                    wire_changes: tvec!(
                        (InOut::In(0), change.clone()),
                        (InOut::Out(0), change.clone())
                    ),
                }));
            }
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        // format swap: chw <-> hwc
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        // geo axis manips
        if model.node_input_facts(node.id)?[1].datum_type.is_opaque() {
            return Ok(None);
        }
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = self.kernel_fmt.h_axis();
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                let geo_axis = a - h_axis;
                (Rm(geo_axis), Rm(kh_axis + geo_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
        let new_op = Conv { pool_spec, ..self.clone() };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::In(1), kernel_adjusted),
                (InOut::Out(0), change.clone())
            ),
        }))
    }

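    /// Code generation picks, in order: the quantized im2col lowering, the lazy
    /// im2col lowering (concrete single-batch shapes with a large kernel surface),
    /// the depthwise lowering, and finally the regular im2col pair.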
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        unsafe {
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
                    .context("in wire_as_quant_im2col")?;
                patch.shunt_outside(model, node.id.into(), wire[0])?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
                    .context("wire_as_lazy_im2col")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_depth_wise(&mut patch, &node.name, &inputs)
                    .context("wire_as_depth_wise")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
                    .context("in wire_as_im2col_pair")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}

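/// Heuristic gating the lazy im2col lowering: single batch, no grouping, and a
/// kernel surface larger than 5.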
fn should_use_lazy(input_shape: &DataShape, pool_spec: &PoolSpec, group: usize) -> bool {
    input_shape.n().unwrap_or(&1) == &1
        && group == 1
        && pool_spec.kernel_shape.iter().product::<usize>() > 5
}

#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    #[test]
    fn onnx_basic_convinteger() {
        let op = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            rctensor0(0u32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let wire = model.wire_node(
            "conv",
            Conv {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: Explicit(tvec![0], tvec![0]),
                    input_channels: 1,
                    output_channels: 1,
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                group: 1,
                q_params: None,
            },
            &[wire[0], kernel, bias],
        )?;
        model.set_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 4); // source + conv + kernel + bias
        let cv = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0]));
        Ok(())
    }
}