Skip to main content

tract_core/ops/cnn/conv/
conv.rs

1use tract_data::itertools::izip;
2use tract_linalg::WeightType;
3use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFormat};
4use tract_num_traits::Zero;
5
6use crate::internal::*;
7use crate::model::*;
8use crate::ops;
9use crate::ops::array::Pad;
10use crate::ops::array::PadMode;
11use crate::ops::binary::TypedBinOp;
12use crate::ops::cast::cast;
13use crate::ops::cnn::PaddingSpec::*;
14use crate::ops::cnn::conv::block_quant::{BlockQuantIntoShape, SplitGroupBlockQuant};
15use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
16use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
17use crate::ops::cnn::wire_reshape_bias_for_bin;
18use crate::ops::einsum::EinSum;
19use crate::ops::math::{Add, Div, Mul, Sub};
20use crate::ops::math::{add, div, mul, sub};
21use crate::ops::matmul::ModePicker;
22use crate::ops::matmul::optimized::AddMatMulGeometry;
23use crate::ops::matmul::optimized::MapOutputAxisToInput;
24use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
25use crate::ops::matmul::quant::wire_ensure_q8_flavour;
26use crate::ops::nn::Reduce;
27
28use super::depth_wise::DepthWise;
29use super::im2col::Im2Col;
30use crate::ops::cnn::conv::KernelFormat;
31use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
32use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
33use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
34
35use tract_linalg::mmm::{MMMInputFormat, MatMatMul};
36use tract_linalg::pack::{PackedFormat, PackedI8K4};
37
38#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
39pub struct Conv {
40    pub pool_spec: PoolSpec,
41    pub kernel_fmt: KernelFormat,
42    pub group: usize,
43    // None -> floats
44    // Some(I32) -> output is I32 (use quantized kernels, but output will be i32). last 2 Q inputs
45    // are ignored
46    // Some(QXX) -> quantized XX, but parameters are ignored (I8, U8, or I32) in favor of last 2 Q inputs
47    pub q_params: Option<DatumType>,
48}
49
50impl Conv {
51    pub fn input_channels(&self) -> usize {
52        self.pool_spec.input_channels
53    }
54
55    pub fn output_channels(&self) -> usize {
56        self.pool_spec.output_channels
57    }
58
59    pub fn wire_kernel_as_g_o_ihw(
60        &self,
61        model: &mut TypedModel,
62        name: &str,
63        mut kernel: OutletId,
64    ) -> TractResult<TVec<OutletId>> {
65        let fact = model.outlet_fact(kernel)?;
66        if fact.is_exotic() {
67            ensure!(self.kernel_fmt == KernelFormat::OIHW && fact.rank() >= 2);
68            kernel = model.wire_node(
69                format!("{name}.prep_kernel.g"),
70                SplitGroupBlockQuant { group: self.group },
71                &[kernel],
72            )?[0];
73            kernel = model.wire_node(
74                format!("{name}.prep_kernel.ihw"),
75                BlockQuantIntoShape {
76                    shape: tvec!(
77                        self.output_channels() / self.group,
78                        self.input_channels() / self.group
79                            * self.pool_spec.kernel_shape.iter().product::<usize>(),
80                    ),
81                },
82                &[kernel],
83            )?[0];
84            Ok(tvec!(kernel))
85        } else {
86            for (ix, op) in self
87                .kernel_fmt
88                .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
89                .into_iter()
90                .enumerate()
91            {
92                kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
93            }
94            Ok(tvec!(kernel))
95        }
96    }
97
98    fn wire_pack_g_o_ihw(
99        &self,
100        model: &mut TypedModel,
101        name: &str,
102        format: &dyn MMMInputFormat,
103        kernel: OutletId,
104    ) -> TractResult<OutletId> {
105        let fact = model.outlet_fact(kernel)?;
106        let wire = if fact.is_exotic() {
107            let fact = model
108                .outlet_fact(kernel)?
109                .exotic_fact
110                .as_ref()
111                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
112                .context("Only manage BlockQuant")?;
113            model.wire_node(
114                format!("{name}.prep_kernel.pack"),
115                OptSimpleMatMulPack {
116                    packed_format: format
117                        .downcast_ref::<PackedBlockQuantFormat>()
118                        .context("Expect a block quant format")?
119                        .clone(),
120                    k: fact.k(),
121                    m: fact.m(),
122                },
123                &[kernel],
124            )?
125        } else {
126            // PackedFormat or a custom numeric packer (e.g. PackedI8K4).
127            model.wire_node(
128                format!("{name}.prep_kernel.pack"),
129                OptMatMulPack {
130                    packers: vec![dyn_clone::clone_box(format)],
131                    k_axis: 2,
132                    mn_axis: 1,
133                    mode_picker: ModePicker::Single,
134                },
135                &[kernel],
136            )?
137        };
138        Ok(wire[0])
139    }
140
141    // group,bias
142    fn wire_bias_as_non_linear(
143        &self,
144        model: &mut TypedModel,
145        name: &str,
146        bias: OutletId,
147        c_group_axis: usize,
148    ) -> TractResult<(ProtoFusedSpec, OutletId)> {
149        use tract_linalg::BinOp::Add;
150        let fact = model.outlet_fact(bias)?;
151        if fact.shape.volume().is_one() {
152            Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
153        } else {
154            let bias = AxisOp::wire_split_axis(
155                model,
156                format!("{name}.reformat_bias"),
157                bias,
158                0,
159                self.group,
160            )?[0];
161            let pfs =
162                ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
163            Ok((pfs, bias))
164        }
165    }
166
    /// Wires a quantized convolution as Im2Col + an i32-accumulating `OptMatMul`, followed by
    /// zero-point compensation, group removal and requantization.
    ///
    /// `wires` must hold exactly nine outlets, in order:
    /// `[x, kernel, bias, x0, x_scale, k0, k_scale, y0, y_scale]`.
    /// Kernel and input are first converted to the i8 flavour by `wire_ensure_q8_flavour`
    /// (which may rewire `kernel`/`k0` and `x`/`x0`). The output datum type is `self.q_params`,
    /// which must be `Some`.
    ///
    /// # Safety
    /// NOTE(review): the caller contract of this `unsafe fn` is not stated in this file —
    /// confirm against the matmul kernel selection / Im2Col requirements.
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wires: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        ensure!(self.q_params.is_some());
        use crate::ops::matmul::quant as qmm;

        // Requested output datum type, used by the final requant step.
        let c_dt = self.q_params.unwrap();
        let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
        else {
            bail!("Wrong number of inputs")
        };
        wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
        wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;

        let a_fact = model.outlet_fact(kernel)?.clone();
        let b_fact = model.outlet_fact(x)?.clone();

        let (_geo, m, k, n) = self.compute_geo(&b_fact)?;
        let (mmm, packing) = self.choose_impl(&b_fact, &a_fact, m, k, &n)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        // A per-channel kernel scale needs an extra trailing axis when C is not the last axis,
        // so that it broadcasts against the (still collapsed) spatial axis after C.
        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            // requant is performed before geo_reshape, so we need at most one geo axis to the
            // right
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        // x0 is handed to Im2Col alongside x (presumably as the padding value — TODO confirm).
        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &b_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[x, x0],
        )?[0];

        // Kernel sums over k (axis 2 of [group, o, ihw]), in i32, for zero-point compensation.
        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        // align sum_A from G,C to "C" shape: N,HW,G,C (or N,G,C,HW)
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        // QSumB sums the im2col output, so only these activation packings are supported.
        ensure!(
            mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some()
                || mmm.packings()[packing].1.downcast_ref::<PackedI8K4>().is_some(),
            "Im2Col/QSumB support PackedFormat or PackedI8K4 activation packings"
        );
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        // sum_b is N,G,HW. make it N,HW,G,C or N,G,C,HW
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        // wire_mm_weights_bias requires the bias to already be in the accumulator type.
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        // Correct the raw i32 product for the x0/k0 zero points using the sums computed above.
        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }
293
294    pub fn wire_remove_group<D: DimLike>(
295        &self,
296        model: &mut TypedModel,
297        name: &str,
298        wire: &[OutletId],
299        mmm_output_shape: &[D],
300        c_axis: usize,
301    ) -> TractResult<TVec<OutletId>> {
302        let m = &mmm_output_shape[c_axis];
303        let op = if self.group == 1 {
304            AxisOp::Rm(c_axis - 1)
305        } else {
306            AxisOp::Reshape(
307                c_axis - 1,
308                tvec!(self.group.to_dim(), m.to_dim()),
309                tvec!(m.to_dim() * self.group),
310            )
311        };
312        model.wire_node(format!("{name}.reshape_group"), op, wire)
313    }
314
315    pub unsafe fn wire_as_im2col_pair(
316        &self,
317        model: &mut TypedModel,
318        name: &str,
319        wire: &[OutletId],
320    ) -> TractResult<TVec<OutletId>> {
321        let &[x, w, bias] = wire else { bail!("Wrong number of inputs") };
322        let x_fact = model.outlet_fact(x)?.clone();
323        let w_fact = model.outlet_fact(w)?.clone();
324        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
325
326        let (_, m, k, n) = self.compute_geo(&x_fact)?;
327        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
328        let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
329        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;
330
331        let padding =
332            model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(x_fact.datum_type)?)?;
333
334        let mut wire: TVec<_> = wire.into();
335        wire[0] = model.wire_node(
336            format!("{name}.im2col"),
337            Im2Col::new(
338                self.pool_spec.clone(),
339                self.group,
340                k,
341                &x_fact.shape,
342                mmm.clone(),
343                packing,
344            )?,
345            &[wire[0], padding],
346        )?[0];
347
348        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;
349
350        let wire = self
351            .wire_mm_weights_bias(
352                model,
353                name,
354                wire[0],
355                g_o_ihw[0],
356                bias,
357                mmm,
358                packing,
359                c_dt,
360                mmm_output_shape.clone().into(),
361                k.to_usize().unwrap(),
362                c_axis,
363                h_axis,
364            )
365            .context("in wire_opt_matmul")?;
366
367        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
368        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
369        Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
370    }
371
372    // always have N and G. G is right before C, c_axis point to C, c_axis-1 points to G
373    fn mmm_output_shape<D: DimLike>(
374        &self,
375        output_shape: &BaseDataShape<D, TVec<D>>,
376    ) -> TractResult<(TVec<D>, usize, usize)> {
377        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
378        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
379            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
380            output_shape.c().clone(),
381            tvec!(geo_collapsed_out),
382        )?;
383        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
384        let mut c_axis = shape.c_axis();
385        let mut h_axis = shape.h_axis();
386        mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
387        mmm_output_shape.insert(c_axis, self.group.into());
388        if h_axis > c_axis {
389            h_axis += 1;
390        }
391        c_axis += 1;
392        Ok((mmm_output_shape, c_axis, h_axis))
393    }
394
395    fn wire_rm_n_if_needed(
396        &self,
397        model: &mut TypedModel,
398        name: &str,
399        wire: &[OutletId],
400    ) -> TractResult<TVec<OutletId>> {
401        if self.pool_spec.data_format.has_n() {
402            Ok(wire.into())
403        } else {
404            model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
405        }
406    }
407
408    fn wire_geo_reshape<D: DimLike>(
409        model: &mut TypedModel,
410        name: &str,
411        wire: &[OutletId],
412        output_shape: &BaseDataShape<D, TVec<D>>,
413    ) -> TractResult<TVec<OutletId>> {
414        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
415        model
416            .wire_node(
417                name,
418                AxisOp::Reshape(
419                    output_shape.h_axis(),
420                    tvec!(geo_collapsed_out.to_dim()),
421                    output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
422                ),
423                wire,
424            )
425            .context("in wire_geo_reshape")
426    }
427
428    pub unsafe fn wire_as_lazy_im2col(
429        &self,
430        model: &mut TypedModel,
431        name: &str,
432        wire: &[OutletId],
433    ) -> TractResult<TVec<OutletId>> {
434        let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
435        let mut x_fact = model.outlet_fact(x)?.clone();
436        let w_fact = model.outlet_fact(kernel)?.clone();
437        let (geo, m, k, n) = self.compute_geo(&x_fact)?;
438        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
439        debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
440        let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
441        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
442        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
443        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
444        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
445            let mut pads = vec![(0, 0); x_fact.rank()];
446            for (ix, ax) in padding.iter().enumerate() {
447                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
448            }
449            let op = crate::ops::array::Pad {
450                mode: crate::ops::array::PadMode::Constant(
451                    Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
452                ),
453                pads,
454            };
455            x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
456            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
457            x_fact = model.outlet_fact(x)?.clone();
458            let concrete_shape = x_fact.shape.as_concrete().unwrap();
459            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
460            geo = valid_pool_spec
461                .compute_geo(&x_fact.shape)?
462                .to_concrete(concrete_shape)?
463                .into_owned();
464        }
465        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
466        let c_stride = input_shape.c_stride();
467        let size_of_b = x_fact.datum_type.size_of() as isize;
468        let n_byte_offsets: Vec<isize> =
469            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
470        // For grouped convs, k offsets cover one group's input slice (ci_per_group channels);
471        // each group reads from a different base offset (group_stride_bytes apart).
472        let ci_per_group = self.input_channels() / self.group;
473        let k_byte_offsets: Vec<isize> = (0..ci_per_group)
474            .flat_map(|ici| {
475                geo.patch
476                    .standard_layout_data_field
477                    .iter()
478                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
479            })
480            .collect();
481        let group_stride_bytes = (ci_per_group * c_stride) as isize * size_of_b;
482        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
483        let packer = mmm.packings()[packing]
484            .1
485            .downcast_ref::<PackedFormat>()
486            .with_context(|| {
487                format_err!(
488                    "Quand Im2Col expects regular packed format, got {:?}",
489                    mmm.packings()[packing].1
490                )
491            })?
492            .clone();
493        let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
494        let x = model.wire_node(
495            format!("{name}.lazyIm2col"),
496            LazyIm2Col { params: Arc::new(params), group: self.group, group_stride_bytes },
497            &[x],
498        )?[0];
499
500        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
501        let wire = self.wire_mm_weights_bias(
502            model,
503            name,
504            x,
505            kernel,
506            bias,
507            mmm,
508            packing,
509            c_dt,
510            mmm_output_shape.clone().into(),
511            k,
512            c_axis,
513            h_axis,
514        )?;
515
516        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
517        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
518        Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
519    }
520
521    #[allow(clippy::type_complexity)]
522    fn compute_geo(
523        &self,
524        input_fact: &TypedFact,
525    ) -> TractResult<(PoolGeometry, usize, usize, TDim)> {
526        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;
527
528        trace!("output channels: {:?}", self.output_channels());
529        let m = self.output_channels() / self.group;
530        let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
531            / self.group;
532        let n: TDim =
533            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();
534        Ok((geo, m, k, n))
535    }
536
537    fn choose_impl(
538        &self,
539        input_fact: &TypedFact,
540        weight_fact: &TypedFact,
541        m: usize,
542        k: usize,
543        n: &TDim,
544    ) -> TractResult<(Box<dyn MatMatMul>, usize)> {
545        let w_dt = weight_fact.datum_type;
546        let x_dt = input_fact.datum_type;
547
548        let acc = if x_dt.is_float() { x_dt } else { i32::datum_type() };
549        if weight_fact.is_exotic() {
550            let bqf = weight_fact
551                .exotic_fact
552                .as_ref()
553                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
554                .unwrap();
555            let weight_type = WeightType::BlockQuant(bqf.format.clone());
556            tract_linalg::ops()
557                .mmm_impls()
558                .iter()
559                .filter(|mmm| mmm.internal_type() == acc)
560                .flat_map(|mmm| {
561                    mmm.packings().iter().enumerate().map(move |(ix, p)| (mmm, ix, &p.0, &p.1))
562                })
563                .filter(|(_, _, pa, pb)| {
564                    pb.precursor() == x_dt.into() && pa.precursor() == weight_type
565                })
566                .map(|(mmm, p, _, _)| (mmm.clone(), p))
567                .min_by_key(|(mmm, _)| {
568                    mmm.quality().cost() as isize * 1000 - (mmm.mr() * mmm.nr()) as isize
569                })
570                .context("Not matmu found")
571        } else {
572            let mmm = tract_linalg::ops()
573                .mmm(acc, Some(m), Some(k), n.to_usize().ok())
574                .context("No matmul found")?;
575            let packing = mmm
576                .packings()
577                .iter()
578                .position(|p| {
579                    p.0.precursor() == w_dt.unquantized().into()
580                        && p.1.precursor() == x_dt.unquantized().into()
581                })
582                .context("No packing found")?;
583            Ok((mmm, packing))
584        }
585    }
586
587    #[allow(clippy::too_many_arguments)]
588    fn wire_mm_weights_bias(
589        &self,
590        model: &mut TypedModel,
591        name: &str,
592        input: OutletId,
593        g_o_ihw: OutletId,
594        bias: OutletId,
595        mmm: Box<dyn MatMatMul>,
596        packing: usize,
597        c_datum_type: DatumType,
598        mmm_output_shape: ShapeFact,
599        k: usize,
600        c_m_axis: usize,
601        c_n_axis: usize,
602    ) -> TractResult<TVec<OutletId>> {
603        ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
604        let a_pack = &mmm.packings()[packing].0;
605        let packed_ker = self
606            .wire_pack_g_o_ihw(model, name, &**a_pack, g_o_ihw)
607            .context("in kernel_as_packed_as")?;
608        let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());
609
610        c_to_a_axis_mapping.push((c_m_axis - 1, 0)); // Group
611        c_to_b_axis_mapping.push((0, 0)); // Batch
612        c_to_b_axis_mapping.push((c_m_axis - 1, 1)); // Group
613
614        let geo = AddMatMulGeometry {
615            k: k.to_dim(),
616            c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
617            c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
618        };
619        let mut ops: Vec<ProtoFusedSpec> =
620            vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
621        let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
622        let bias_fact = model.outlet_fact(bias)?;
623        if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
624            let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
625            wires.push(bias);
626            ops.push(fused);
627        }
628        ops.push(ProtoFusedSpec::Store(vec![unsafe {
629            mmm.c_view(Some(c_m_axis), Some(c_n_axis))
630        }]));
631        model.wire_node(
632            format!("{name}.matmatmul"),
633            OptMatMul::new(
634                vec![mmm],
635                ModePicker::Single,
636                c_datum_type.fact(mmm_output_shape),
637                Some(c_m_axis),
638                Some(c_n_axis),
639                ops,
640                packing == 0 && self.group == 1,
641            )?,
642            &wires,
643        )
644    }
645
646    pub fn wire_as_depth_wise(
647        &self,
648        model: &mut TypedModel,
649        name: &str,
650        wire: &[OutletId],
651    ) -> TractResult<OutletId> {
652        let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
653        let x_fact = model.outlet_fact(x)?.clone();
654        let x_shape = x_fact.shape.as_concrete().unwrap();
655        let ConcretePoolGeometry { input_shape, patch, output_shape } =
656            self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
657        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
658        let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
659        bias = wire_reshape_bias_for_bin(
660            model,
661            name,
662            bias,
663            x_fact.rank(),
664            c_axis,
665            self.output_channels(),
666        )?[0];
667        let op = DepthWise::new(patch, input_shape, output_shape);
668        Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
669    }
670
671    /// Eligibility for the direct register-blocked conv (see `blocked.rs`):
672    /// f32 NCHW, kernel width 1 (extent on H only), unit stride/dilation on the
673    /// contiguous W axis, grouped with a *small* number of out-channels per group
674    /// (where the im2col matmul's M-tile would be mostly wasted). Concrete shape
675    /// required. Returns the fully-parameterised op, or None to fall back.
676    fn try_blocked_conv(&self, input_fact: &TypedFact) -> Option<super::BlockedConv> {
677        // The direct blocked conv beats im2col on wasm (no AMX; the gather +
678        // wasted-M-tile matmul is slow) but LOSES on native, where shape-aware
679        // AMX dispatch already handles the tiny-M matmul well. So: on by default
680        // on wasm, opt-in on native. Env overrides either way for A/B.
681        let enabled = if cfg!(target_family = "wasm") {
682            std::env::var("TRACT_DISABLE_BLOCKED_CONV").is_err()
683        } else {
684            std::env::var("TRACT_ENABLE_BLOCKED_CONV").is_ok()
685        };
686        if !enabled {
687            return None;
688        }
689        if self.q_params.is_some() {
690            return None;
691        }
692        if input_fact.datum_type != f32::datum_type() {
693            return None;
694        }
695        if self.pool_spec.data_format != crate::ops::nn::DataFormat::NCHW {
696            return None;
697        }
698        if self.pool_spec.rank() != 2 || self.pool_spec.kernel_shape[1] != 1 {
699            return None;
700        }
701        if self.pool_spec.stride(1) != 1 || self.pool_spec.dilation(1) != 1 {
702            return None;
703        }
704        let group = self.group;
705        let oc = self.output_channels();
706        let c_in = self.input_channels();
707        if group == 0 || !oc.is_multiple_of(group) || !c_in.is_multiple_of(group) {
708            return None;
709        }
710        let ocg = oc / group;
711        // Win condition: tiny per-group output count makes the im2col matmul's
712        // m-tile wasteful. Large ocg packs the tile fine — leave it to im2col.
713        if ocg == 0 || ocg > 8 {
714            return None;
715        }
716        let concrete = input_fact.shape.as_concrete()?;
717        let shape = self.pool_spec.data_format.shape(concrete).ok()?;
718        let h_axis = shape.h_axis();
719        let h_in = concrete[h_axis];
720        let w = concrete[h_axis + 1];
721        let pads = self.pool_spec.computed_padding(shape.hw_dims());
722        Some(super::BlockedConv {
723            n: *shape.n().unwrap_or(&1),
724            c_in,
725            h_in,
726            w,
727            oc,
728            group,
729            kh: self.pool_spec.kernel_shape[0],
730            stride_h: self.pool_spec.stride(0),
731            dil_h: self.pool_spec.dilation(0),
732            pad_before_h: pads[0].pad_before,
733            h_out: pads[0].convoluted,
734        })
735    }
736
737    fn wire_as_blocked_conv(
738        &self,
739        model: &mut TypedModel,
740        name: &str,
741        wire: &[OutletId],
742        op: super::BlockedConv,
743    ) -> TractResult<OutletId> {
744        let &[x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
745        // Kernel → [group, ocg, icg·kh] (group-major, i-major/h-minor); its flat
746        // layout is exactly the [oc, icg·kh] the op indexes.
747        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
748        Ok(model.wire_node(name, op, &[x, g_o_ihw[0], bias])?[0])
749    }
750
751    fn declutter_stride_slice_to_downsample(
752        &self,
753        model: &TypedModel,
754        node: &TypedNode,
755    ) -> TractResult<Option<TypedModelPatch>> {
756        let spatial_rank = self.pool_spec.rank();
757        if let Some(axis) = (0..spatial_rank).find(|&ax| {
758            self.pool_spec.stride(ax) > 1
759                && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
760                && (self.pool_spec.kernel_shape[ax] == 1
761                    || self.pool_spec.dilation(ax).is_multiple_of(self.pool_spec.stride(ax)))
762        }) {
763            let input_fact = model.outlet_fact(node.inputs[0])?;
764            let downsample_factor = self.pool_spec.stride(axis);
765            let mut new_op = self.clone();
766            if new_op.pool_spec.dilation(axis) > 1 {
767                new_op.pool_spec.dilations.as_mut().unwrap()[axis] =
768                    new_op.pool_spec.dilations.as_mut().unwrap()[axis].divceil(downsample_factor);
769            }
770            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
771            let mut patch = TypedModelPatch::default();
772            let mut taps = patch.taps(model, &node.inputs)?;
773            let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
774            taps[0] = patch.wire_node(
775                format!("{}.downsample.{}", node.name, axis),
776                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
777                &[taps[0]],
778            )?[0];
779            let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
780            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
781            return Ok(Some(patch));
782        }
783        Ok(None)
784    }
785
    /// Rewrite a pointwise convolution as an `EinSum` (plus an optional bias `Add`).
    ///
    /// Only fires for ungrouped convolutions (`group == 1`) with unit strides and dilations,
    /// a kernel whose spatial volume is 1, and zero computed padding on every spatial axis.
    /// The kernel is reshaped to a 2-D `[O, I]` operand. Quantized convs keep the bias (and
    /// the quantization inputs) as einsum inputs with an `i32` operating type. Float convs
    /// run the einsum on the two first inputs only and add the bias afterwards unless it is
    /// a known all-zero constant. The final node takes over the original node's name.
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            // Start from the conv's own axes mapping; the kernel slot is rebuilt below.
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            // Kernel -> [group, o, i·(spatial volume)]; with group == 1 and a spatial volume
            // of 1, dropping the leading group axis leaves a plain [O, I] matrix.
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            // Forget whatever the mapping said about the original kernel layout and declare
            // the new 2-D kernel as (O, I).
            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                // Quantized: the bias stays an einsum input; a per-channel bias is tied to O.
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                // Float: einsum over (x, kernel) only, bias handled by a separate Add.
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                // Skip the Add only when the bias is a constant known to be all zeros.
                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    // Reshape the scalar or [co] bias so it broadcasts against the output:
                    // all ones, except co on the channel axis for a per-channel bias.
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            // The last node of the replacement inherits the conv's name.
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }
859
860    fn declutter_precursor_padding(
861        &self,
862        model: &TypedModel,
863        node: &TypedNode,
864    ) -> TractResult<Option<TypedModelPatch>> {
865        rule_if!(!matches!(
866            self.pool_spec.padding,
867            ExplicitOnnxPool(_, _, _) | SameLower | SameUpper
868        ));
869        let prec = model.node(node.inputs[0].node);
870        rule_if_some!(pad = prec.op_as::<Pad>());
871        rule_if_let!(PadMode::Constant(value) = &pad.mode);
872        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
873        rule_if!(value.is_zero()?);
874        rule_if!(pad.pads[shape.c_axis()] == (0, 0));
875        if self.pool_spec.data_format.has_n() {
876            rule_if!(pad.pads[0] == (0, 0));
877        }
878        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
879        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
880        if let Explicit(bef, aft) = &self.pool_spec.padding {
881            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
882            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
883        }
884        let padding = Explicit(before, after);
885        let mut new = self.clone();
886        new.pool_spec.padding = padding;
887        let mut patch = TypedModelPatch::default();
888        let mut wire = patch.taps(model, &node.inputs)?;
889        wire[0] = patch.tap_model(model, prec.inputs[0])?;
890        let wire = patch.wire_node(&node.name, new, &wire)?;
891        patch.shunt_outside(model, node.id.into(), wire[0])?;
892        Ok(Some(patch))
893    }
894
895    fn declutter_channel_arithmetic_succ(
896        &self,
897        model: &TypedModel,
898        node: &TypedNode,
899    ) -> TractResult<Option<TypedModelPatch>> {
900        rule_if!(self.q_params.is_none());
901        rule_if!(self.group == 1);
902        rule_if_let!(&[succ_outlet] = &*node.outputs[0].successors);
903        let succ = model.node(succ_outlet.node);
904        rule_if_some!(bin = succ.op_as::<TypedBinOp>());
905        let other_input = succ.inputs[1 - succ_outlet.slot];
906        let axes_mapping = model.node_axes_mapping(succ.id)?;
907        let input_shape =
908            self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
909        let conv_c_axis = input_shape.c_axis();
910        rule_if!(
911            axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
912                [1 - succ_outlet.slot]
913                .len()
914                == 1
915        );
916        let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
917        other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
918        rule_if!(*other_expected_shape == *model.outlet_fact(other_input)?.shape);
919
920        let mut patch = TypedModelPatch::default();
921        let [input, mut kernel, mut bias] = *patch.taps(model, &node.inputs)? else {
922            panic!("Expect three inputs");
923        };
924        let name = &node.name;
925        let succ_name = &succ.name;
926
927        let operand = patch.tap_model(model, other_input)?;
928
929        let renamed_bias = format!("{name}.{succ_name}.bias");
930        let renamed_kernel = format!("{name}.{succ_name}.kernel");
931        bias = wire_reshape_bias_for_bin(
932            &mut patch,
933            format!("{renamed_bias}.reshape"),
934            bias,
935            1,
936            0,
937            self.output_channels(),
938        )?[0];
939
940        let operand = wire_reshape_bias_for_bin(
941            &mut patch,
942            format!("{renamed_bias}.reshape_operand"),
943            operand,
944            1,
945            0,
946            self.output_channels(),
947        )?[0];
948
949        let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
950        let kernel_fact = patch.outlet_fact(kernel)?;
951        let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
952        operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
953            self.output_channels().to_dim();
954        let operand_for_kernel = patch.wire_node(
955            format!("{renamed_kernel}.reshape_operand"),
956            AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
957            &[operand],
958        )?[0];
959
960        if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
961            bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
962        } else if bin.0.is::<Sub>() {
963            bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
964        } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
965            bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
966            kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
967        } else if bin.0.is::<Div>() {
968            bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
969            kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
970        } else if bin.0.is::<Add>() {
971            bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
972        } else if bin.0.is::<Mul>() {
973            bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
974            kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
975        } else {
976            return Ok(None);
977        };
978        let wire = patch.wire_node(&node.name, self.clone(), &[input, kernel, bias])?[0];
979        patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
980        Ok(Some(patch))
981    }
982}
983
984impl Op for Conv {
985    fn name(&self) -> StaticName {
986        "Conv".into()
987    }
988
989    fn info(&self) -> TractResult<Vec<String>> {
990        let mut info = self.pool_spec.info();
991        info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
992        Ok(info)
993    }
994
995    fn validation(&self) -> Validation {
996        Validation::Rounding
997    }
998
999    op_as_typed_op!();
1000}
1001
1002impl EvalOp for Conv {
1003    fn is_stateless(&self) -> bool {
1004        true
1005    }
1006
1007    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
1008        let mut model = TypedModel::default();
1009        let wire: TVec<OutletId> = inputs
1010            .iter()
1011            .enumerate()
1012            .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
1013            .collect::<TractResult<_>>()?;
1014        let wire = unsafe {
1015            if self.q_params.is_some() {
1016                self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
1017            } else {
1018                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
1019            }
1020        };
1021        model.select_output_outlets(&wire)?;
1022        model.into_runnable()?.run(inputs)
1023    }
1024}
1025
/// Typed-graph integration of [`Conv`]: output fact inference, axes tracking, decluttering
/// rewrites, cost estimation, axis-change propagation and lowering (codegen).
impl TypedOp for Conv {
    /// Validate the inputs and compute the single output fact.
    ///
    /// Inputs are `[x, kernel, bias]`, plus six more when `q_params` is set: indices 3, 5
    /// and 7 must be `i32`, indices 4, 6 and 8 must be floats (presumably zero-point / scale
    /// pairs — TODO confirm). The bias must be a scalar or hold one value per output
    /// channel. Quantized convs output `q_params`' datum type; float convs require input and
    /// kernel to share a datum type unless the kernel is exotic.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
        let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
        // Non-numeric kernel datum types are only accepted with the OIHW layout.
        ensure!(inputs[1].datum_type.is_number() || self.kernel_fmt == KernelFormat::OIHW);
        if inputs.len() != 3 + q_inputs {
            bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
        }
        if self.q_params.is_some() {
            ensure!(inputs[2].datum_type == i32::datum_type());
            ensure!(inputs[3].datum_type == i32::datum_type());
            ensure!(inputs[4].datum_type.is_float());
            ensure!(inputs[5].datum_type == i32::datum_type());
            ensure!(inputs[6].datum_type.is_float());
            ensure!(inputs[7].datum_type == i32::datum_type());
            ensure!(inputs[8].datum_type.is_float());
        }
        // Kernel rank = spatial rank + 2 (output and input channel axes).
        ensure!(self.pool_spec.rank() + 2 == inputs[1].shape.len());
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
            anyhow::ensure!(bef.len() == self.pool_spec.rank());
            anyhow::ensure!(after.len() == self.pool_spec.rank());
        }
        ensure!(
            inputs[2].rank() == 0
                || (inputs[2].rank() == 1
                    && inputs[2].shape.volume() == self.output_channels().to_dim()),
            "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
            self.output_channels(),
            inputs[2]
        );
        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some(dt) = self.q_params {
            fact.datum_type = dt;
        } else {
            // NOTE(review): the message mentions the bias, but only input vs weights datum
            // types are compared here.
            ensure!(
                inputs[1].is_exotic() || inputs[0].datum_type == inputs[1].datum_type,
                "Convolution input, weights and bias must have the same type, got {inputs:?}",
            )
        }
        Ok(tvec!(fact))
    }

    /// Describe how axes flow from inputs to the output.
    ///
    /// The input channel axis is named `I` and the output channel axis `O` (not linked to
    /// each other). The batch axis `N`, when present, is linked through. A spatial axis is
    /// linked through only when its kernel extent is 1, with unit dilation and stride, and no
    /// computed padding. Rank-1 quantization inputs are tied to `I` (indices 3, 4) or `O`
    /// (indices 5 to 8).
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        // Axis labels for spatial dims: H, W, X, Y, Z, then 'a', 'b', ...
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let padding = self.pool_spec.computed_padding(shape.hw_dims());
        for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim == 1
                && self.pool_spec.dilation(ix) == 1
                && self.pool_spec.stride(ix) == 1
                && padding[ix].pad_before.is_zero()
                && padding[ix].pad_after.is_zero()
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking(repr, (InOut::Out(0), ix + h_axis))?;
            }
        }
        if self.q_params.is_some() {
            for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
                if qp.rank() == 1 {
                    axes = match qp_ix {
                        3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
                        5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        _ => unreachable!(),
                    };
                }
            }
        }
        Ok(axes)
    }

    /// Try each decluttering rewrite in order; the first one producing a patch wins.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Run one rewrite, tagging errors and the resulting patch with its name.
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_stride_slice_to_downsample);
        pass!(declutter_as_einsum);
        pass!(declutter_channel_arithmetic_succ);
        pass!(declutter_precursor_padding);
        Ok(None)
    }

    /// Multiply-accumulate count: batch × input channels × output channels × output points ×
    /// kernel spatial volume, divided by the number of groups.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!((
            Cost::FMA(inputs[0].datum_type),
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )))
    }

    /// Propagate an axis change through the convolution when possible.
    ///
    /// Handled cases: removing an axis from the bias (the op is kept as is), removing the
    /// batch axis, swapping the channel axis between first and last position (switching the
    /// data format), and removing / adding / moving spatial axes (adjusting the pool spec and
    /// the kernel accordingly; removal only for extent-1, unit stride and dilation axes).
    /// Changes on the kernel input are never accepted.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        rule_if!(io != InOut::In(1));
        if io == InOut::In(2)
            && let &AxisOp::Rm(_) = change
        {
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(self.clone())),
                wire_changes: tvec!(),
            }));
        }
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        // remove n
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(op)),
                    wire_changes: tvec!(
                        (InOut::In(0), change.clone()),
                        (InOut::Out(0), change.clone())
                    ),
                }));
            }
            // Any other change must keep the batch axis in front.
            rule_if!(change.transform_axis(n).map(|axis| axis == 0).unwrap_or(false));
        }
        // format swap: chw <-> hwc
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        // geo axis manips
        // Spatial changes also rewrite the kernel, which is not possible for exotic kernels.
        rule_if!(!model.node_input_facts(node.id)?[1].is_exotic());
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = self.kernel_fmt.h_axis();
        // Translate the data-tensor change into pool-spec (geo) and kernel coordinates.
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                let geo_axis = a - h_axis;
                (Rm(geo_axis), Rm(kh_axis + geo_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
        let new_op = Conv { pool_spec, ..self.clone() };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::In(1), kernel_adjusted),
                (InOut::Out(0), change.clone())
            ),
        }))
    }

    /// Lower the convolution, choosing the first applicable strategy:
    /// 1. quantized im2col when `q_params` is set;
    /// 2. register-blocked direct conv when `try_blocked_conv` accepts the input;
    /// 3. lazy im2col when the input shape is concrete and `should_use_lazy` agrees;
    /// 4. depthwise when `group == input_channels == output_channels` (group != 1) and the
    ///    input shape is concrete;
    /// 5. otherwise eager im2col + matmul.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        // NOTE(review): the wiring helpers are called inside `unsafe`; their safety contract
        // is defined where they are declared, outside this view.
        unsafe {
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::new("quantized-codegen");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
                    .context("in wire_as_quant_im2col")?;
                patch.shunt_outside(model, node.id.into(), wire[0])?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if let Some(op) = self.try_blocked_conv(input_fact) {
                // Direct register-blocked conv for the small-ocg NCHW kw=1 class;
                // beats lazy im2col by avoiding the gather + wasted M-tile matmul.
                let mut patch = TypedModelPatch::new("blocked-conv");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_blocked_conv(&mut patch, &node.name, &inputs, op)
                    .context("wire_as_blocked_conv")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| should_use_lazy(&self.pool_spec, self.group, s, input_fact.datum_type))
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("lazy-im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
                    .context("wire_as_lazy_im2col")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let mut patch = TypedModelPatch::new("depth_wise");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_depth_wise(&mut patch, &node.name, &inputs)
                    .context("wire_as_depth_wise")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else {
                let mut patch = TypedModelPatch::new("im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
                    .context("in wire_as_im2col_pair")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}
1333
/// Kernel volume (product of the spatial kernel extents) from which LazyIm2col is chosen
/// over eager Im2col by default.
///
/// The two lowerings pay different overheads. LazyIm2col adds a gather indirection for
/// every output position. Eager Im2col pays for materialising a scratch buffer: one large
/// allocation plus strided copies. The indirection is the cheaper of the two for tiny
/// kernels; the materialisation cost dominates once the kernel grows.
///
/// The value is deliberately conservative: measurements on Apple AMX already favour lazy
/// from a kernel volume of 4. The crossover is likely lower still on memory-constrained
/// targets such as embedded ARM. Set `TRACT_LAZY_IM2COL_MIN_KERNEL` to experiment with
/// other thresholds.
const DEFAULT_LAZY_IM2COL_MIN_KERNEL: usize = 6;

/// Read a `usize` tuning knob from environment variable `var`, returning `default` when the
/// variable is unset, not valid Unicode, or does not parse as a `usize`.
fn lazy_im2col_env_knob(var: &str, default: usize) -> usize {
    match std::env::var(var).map(|raw| raw.parse::<usize>()) {
        Ok(Ok(value)) => value,
        _ => default,
    }
}

/// Effective minimum kernel volume for LazyIm2col. The environment is consulted on the first
/// call only; the result is cached for the rest of the process.
fn lazy_im2col_min_kernel() -> usize {
    static MIN_KERNEL: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *MIN_KERNEL.get_or_init(|| {
        lazy_im2col_env_knob("TRACT_LAZY_IM2COL_MIN_KERNEL", DEFAULT_LAZY_IM2COL_MIN_KERNEL)
    })
}

/// Eager-Im2col scratch size, in bytes, from which LazyIm2col is chosen whatever the kernel
/// volume.
///
/// Eager Im2col writes a packed `[k, n]` scratch of `k·n·sizeof` bytes, which the matmul
/// then reads back. A small scratch stays hot in cache, so that round-trip is cheap and the
/// kernel-volume threshold decides. A large scratch turns into a pure memory-bandwidth cost:
/// several MB are written and re-read on every inference, which outweighs LazyIm2col's
/// per-panel gather. The kernel-volume threshold alone misses this case, because a *small*
/// kernel over a *large* output (big `n`) still materialises several MB.
///
/// The crossover depends on the target. On WASM, materialisation hurts more (no hardware
/// prefetch help, bounds-checked stores), and lazy wins from about 1 MiB of scratch. On
/// native CPUs, caches and prefetchers absorb a few MB, so the crossover is higher (about
/// 4 MiB, measured on Apple Silicon). The override `TRACT_LAZY_IM2COL_MAX_EAGER_BYTES`
/// works on every target. This value is the key knob for the canary-model regression gate.
#[cfg(target_family = "wasm")]
const DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES: usize = 1024 * 1024;
/// Native (non-WASM) default for the eager-Im2col scratch ceiling: 4 MiB.
#[cfg(not(target_family = "wasm"))]
const DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES: usize = 4 * 1024 * 1024;

/// Effective eager-Im2col scratch ceiling, in bytes. The environment is consulted on the
/// first call only; the result is cached for the rest of the process.
fn lazy_im2col_max_eager_bytes() -> usize {
    static MAX_EAGER_BYTES: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *MAX_EAGER_BYTES.get_or_init(|| {
        lazy_im2col_env_knob(
            "TRACT_LAZY_IM2COL_MAX_EAGER_BYTES",
            DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES,
        )
    })
}
1387
1388fn should_use_lazy(
1389    pool_spec: &PoolSpec,
1390    group: usize,
1391    input_shape: &[usize],
1392    dt: DatumType,
1393) -> bool {
1394    // Depthwise convs (group == in_channels == out_channels) have a specialised
1395    // `DepthWise` op downstream that's much faster than the generic im2col + matmul
1396    // path on every backend we measured (Apple AMX, x64, aarch64). Don't intercept
1397    // them here — let the dispatch in `conv.rs` reach `wire_as_depth_wise`.
1398    let is_depthwise =
1399        group > 1 && group == pool_spec.input_channels && group == pool_spec.output_channels;
1400    if is_depthwise {
1401        return false;
1402    }
1403    let Ok(output_shape) = pool_spec.output_shape(input_shape) else { return false };
1404    // LazyIm2col's offset tables are built for a single batch.
1405    if output_shape.n().unwrap_or(&1) != &1 {
1406        return false;
1407    }
1408    let kernel_volume = pool_spec.kernel_shape.iter().product::<usize>();
1409    // Primary rule: kernel volume. LazyIm2col's per-output-position gather indirection
1410    // is cheap relative to materialising the scratch for a sizeable kernel.
1411    if kernel_volume >= lazy_im2col_min_kernel() {
1412        return true;
1413    }
1414    // Shape-aware rule: prefer lazy when the eager scratch (`k·n·sizeof`) is large,
1415    // even for a small kernel. `n` is the output spatial volume — the dimension the
1416    // kernel-volume rule ignores but which actually drives the materialisation cost.
1417    let n: usize = output_shape.hw_dims().iter().product();
1418    let k = pool_spec.input_channels * kernel_volume / group;
1419    let eager_scratch_bytes = k.saturating_mul(n).saturating_mul(dt.size_of());
1420    eager_scratch_bytes >= lazy_im2col_max_eager_bytes()
1421}
1422
#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    /// Quantized conv producing i32: a 3x3 u8 image convolved with a 2x2 all-ones u8 kernel.
    /// Input 3 (value 1) is subtracted from every data value, so e.g. the top-left output is
    /// (1 + 2 + 4 + 5) - 4·1 = 8.
    #[test]
    fn onnx_basic_convinteger() {
        let op = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            rctensor0(0u32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    /// A zero constant `Pad` of (1, 0) on the spatial axis in front of a conv with explicit
    /// zero padding must be absorbed by decluttering: the `Pad` node disappears and the conv
    /// ends up with explicit padding before = [1], after = [0].
    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let wire = model.wire_node(
            "conv",
            Conv {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: Explicit(tvec![0], tvec![0]),
                    input_channels: 1,
                    output_channels: 1,
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                group: 1,
                q_params: None,
            },
            &[wire[0], kernel, bias],
        )?;
        model.select_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 4); // source + conv + kernel + bias
        let cv = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0])); // pad's (1, 0) merged into conv padding
        Ok(())
    }
}