use tract_data::itertools::izip;
use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFormat};
use tract_linalg::WeightType;
use tract_num_traits::Zero;

use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::cnn::conv::block_quant::{BlockQuantIntoShape, SplitGroupBlockQuant};
use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::PaddingSpec::*;
use crate::ops::einsum::EinSum;
use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::matmul::ModePicker;
use crate::ops::nn::Reduce;

use super::depth_wise::DepthWise;
use super::im2col::Im2Col;
use crate::ops::cnn::conv::{block_quant_aware_weight_shape, KernelFormat};
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::mmm::{MMMInputFormat, MatMatMul};
use tract_linalg::pack::PackedFormat;

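/// The generic convolution operator.
///
/// Expects three inputs, `(input, kernel, bias)`. When `q_params` is set
/// (quantized operation), six more inputs follow: `x0, x_scale, k0, k_scale,
/// y0, y_scale`, the zero points and scales of input, kernel and output.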
#[derive(Debug, Clone, new, Hash)]
pub struct Conv {
    pub pool_spec: PoolSpec,
    pub kernel_fmt: KernelFormat,
    pub group: usize,
    pub q_params: Option<DatumType>,
}

impl Conv {
    pub fn input_channels(&self) -> usize {
        self.pool_spec.input_channels
    }

    pub fn output_channels(&self) -> usize {
        self.pool_spec.output_channels
    }

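    /// Normalize the kernel to the G,O,IHW layout (group, output channels per
    /// group, input channels per group times kernel surface) that the matmul
    /// lowering expects. Block-quantized (opaque) kernels go through dedicated
    /// block-quant ops instead of plain axis manipulations.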
    pub fn wire_kernel_as_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut kernel: OutletId,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(kernel)?;
        if fact.datum_type.is_opaque() {
            ensure!(self.kernel_fmt == KernelFormat::OIHW && fact.rank() == 0);
            kernel = model.wire_node(
                format!("{name}.prep_kernel.g"),
                SplitGroupBlockQuant { group: self.group },
                &[kernel],
            )?[0];
            kernel = model.wire_node(
                format!("{name}.prep_kernel.ihw"),
                BlockQuantIntoShape {
                    shape: tvec!(
                        self.output_channels() / self.group,
                        self.input_channels() / self.group
                            * self.pool_spec.kernel_shape.iter().product::<usize>(),
                    ),
                },
                &[kernel],
            )?[0];
            Ok(tvec!(kernel))
        } else {
            for (ix, op) in self
                .kernel_fmt
                .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
                .into_iter()
                .enumerate()
            {
                kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
            }
            Ok(tvec!(kernel))
        }
    }

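    /// Pack a G,O,IHW kernel into the weight format required by the selected
    /// matmul implementation: a block-quant packer for opaque weights, a
    /// regular `PackedFormat` for plain numeric ones.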
    fn wire_pack_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        format: &dyn MMMInputFormat,
        kernel: OutletId,
    ) -> TractResult<OutletId> {
        let fact = model.outlet_fact(kernel)?;
        let wire = if fact.datum_type.is_opaque() {
            let fact = model
                .outlet_fact(kernel)?
                .opaque_fact
                .as_ref()
                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
                .context("Only BlockQuant opaque weights are supported")?;
            model.wire_node(
                format!("{name}.prep_kernel.pack"),
                OptSimpleMatMulPack {
                    packed_format: format
                        .downcast_ref::<PackedBlockQuantFormat>()
                        .context("Expected a block quant packing format")?
                        .clone(),
                    k: fact.k(),
                    m: fact.m(),
                },
                &[kernel],
            )?
        } else {
            let format = format
                .downcast_ref::<PackedFormat>()
                .context("Expected regular packing for numeric weights")?;
            model.wire_node(
                format!("{name}.prep_kernel.pack"),
                OptMatMulPack {
                    packers: vec![format.clone()],
                    k_axis: 2,
                    mn_axis: 1,
                    mode_picker: ModePicker::Single,
                },
                &[kernel],
            )?
        };
        Ok(wire[0])
    }

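    /// Express the bias as a fused matmul post-op: a scalar addition when the
    /// bias is uniform, otherwise a per-row addition after splitting the
    /// channel axis into group and per-group channels.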
    fn wire_bias_as_non_linear(
        &self,
        model: &mut TypedModel,
        name: &str,
        bias: OutletId,
        c_group_axis: usize,
    ) -> TractResult<(ProtoFusedSpec, OutletId)> {
        use tract_linalg::BinOp::Add;
        let fact = model.outlet_fact(bias)?;
        if fact.shape.volume().is_one() {
            Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
        } else {
            let bias = AxisOp::wire_split_axis(
                model,
                format!("{name}.reformat_bias"),
                bias,
                0,
                self.group,
            )?[0];
            let pfs =
                ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
            Ok((pfs, bias))
        }
    }

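    /// Quantized lowering: im2col the input, run the integer matmul against
    /// the packed kernel, compensate the kernel and input zero points using
    /// row and column sums, then requantize to the output type.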
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wires: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        ensure!(self.q_params.is_some());
        use crate::ops::matmul::quant as qmm;

        let c_dt = self.q_params.unwrap();
        let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
        else {
            bail!("Wrong number of inputs")
        };
        wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
        wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;

        let a_fact = model.outlet_fact(kernel)?.clone();
        let b_fact = model.outlet_fact(x)?.clone();

        let (_geo, m, k, n) = self.compute_geo(&b_fact)?;
        let (mmm, packing) = self.choose_impl(&b_fact, &a_fact, m, k, &n)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            // Requantization happens before the final geometry reshape, so a
            // per-channel kernel scale must be aligned with where the channel
            // axis sits at that point.
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &b_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[x, x0],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }

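    /// Merge the group axis back into the channel axis of the matmul output
    /// (or simply drop it when group == 1).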
    pub fn wire_remove_group<D: DimLike>(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        mmm_output_shape: &[D],
        c_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let m = &mmm_output_shape[c_axis];
        let op = if self.group == 1 {
            AxisOp::Rm(c_axis - 1)
        } else {
            AxisOp::Reshape(
                c_axis - 1,
                tvec!(self.group.to_dim(), m.to_dim()),
                tvec!(m.to_dim() * self.group),
            )
        };
        model.wire_node(format!("{name}.reshape_group"), op, wire)
    }

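    /// Float lowering: materialize the im2col patches, then run a single
    /// optimized matmul against the packed kernel, bias fused in.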
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[x, w, bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let w_fact = model.outlet_fact(w)?.clone();
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);

        let (_, m, k, n) = self.compute_geo(&x_fact)?;
        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
        let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;

        let padding =
            model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(x_fact.datum_type)?)?;

        let mut wire: TVec<_> = wire.into();
        wire[0] = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &x_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[wire[0], padding],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;

        let wire = self
            .wire_mm_weights_bias(
                model,
                name,
                wire[0],
                g_o_ihw[0],
                bias,
                mmm,
                packing,
                c_dt,
                mmm_output_shape.clone().into(),
                k.to_usize().unwrap(),
                c_axis,
                h_axis,
            )
            .context("in wire_mm_weights_bias")?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
    }

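    /// Compute the matmul output shape before degrouping: spatial dims are
    /// collapsed into one axis and the channel axis is split into group and
    /// per-group channels. Also returns the positions of the per-group
    /// channel axis and of the collapsed spatial axis.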
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
        mmm_output_shape.insert(c_axis, self.group.into());
        if h_axis > c_axis {
            h_axis += 1;
        }
        c_axis += 1;
        Ok((mmm_output_shape, c_axis, h_axis))
    }

    fn wire_rm_n_if_needed(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if self.pool_spec.data_format.has_n() {
            Ok(wire.into())
        } else {
            model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
        }
    }

    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<TVec<OutletId>> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        model
            .wire_node(
                name,
                AxisOp::Reshape(
                    output_shape.h_axis(),
                    tvec!(geo_collapsed_out.to_dim()),
                    output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
                ),
                wire,
            )
            .context("in wire_geo_reshape")
    }

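    /// Variant that does not materialize the im2col buffer: any padding is
    /// applied up front, then a `LazyIm2Col` node extracts patches on demand
    /// from precomputed byte offsets.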
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let mut x_fact = model.outlet_fact(x)?.clone();
        let w_fact = model.outlet_fact(kernel)?.clone();
        let (geo, m, k, n) = self.compute_geo(&x_fact)?;
        let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
        debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
        let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); x_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
            x_fact = model.outlet_fact(x)?.clone();
            let concrete_shape = x_fact.shape.as_concrete().unwrap();
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&x_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = x_fact.datum_type.size_of() as isize;
        let n_byte_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_byte_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
        let packer = mmm.packings()[packing]
            .1
            .downcast_ref::<PackedFormat>()
            .with_context(|| {
                format_err!(
                    "Lazy Im2Col expects regular packed format, got {:?}",
                    mmm.packings()[packing].1
                )
            })?
            .clone();
        let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
        let x = model.wire_node(
            format!("{name}.lazyIm2col"),
            LazyIm2Col { params: Arc::new(params) },
            &[x],
        )?[0];

        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            x,
            kernel,
            bias,
            mmm,
            packing,
            c_dt,
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
    }

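    /// Derive the matmul dimensions from the pooling geometry: per group,
    /// m is the output channel count, k the input channel count times the
    /// kernel surface (the im2col row length), and n the (possibly symbolic)
    /// number of output spatial positions.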
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim)> {
        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;

        trace!("output channels: {:?}", self.output_channels());
        let m = self.output_channels() / self.group;
        let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
            / self.group;
        let n: TDim =
            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();
        Ok((geo, m, k, n))
    }

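    /// Pick a matmul implementation and one of its packings. Block-quantized
    /// weights trigger a scan of all implementations for a packing whose
    /// precursor types match the weights and activations, preferring better
    /// quality then larger tiles; otherwise linalg picks from m, k and n.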
    fn choose_impl(
        &self,
        input_fact: &TypedFact,
        weight_fact: &TypedFact,
        m: usize,
        k: usize,
        n: &TDim,
    ) -> TractResult<(Box<dyn MatMatMul>, usize)> {
        let w_dt = weight_fact.datum_type;
        let x_dt = input_fact.datum_type;

        let acc = if x_dt.is_float() { x_dt } else { i32::datum_type() };
        if w_dt.is_opaque() {
            let bqf = weight_fact
                .opaque_fact
                .as_ref()
                .and_then(|of| of.downcast_ref::<BlockQuantFact>())
                .unwrap();
            let weight_type = WeightType::BlockQuant(bqf.format.clone());
            tract_linalg::ops()
                .mmm_impls()
                .iter()
                .filter(|mmm| mmm.internal_type() == acc)
                .flat_map(|mmm| {
                    mmm.packings().iter().enumerate().map(move |(ix, p)| (mmm, ix, &p.0, &p.1))
                })
                .filter(|(_, _, pa, pb)| {
                    pb.precursor() == x_dt.into() && pa.precursor() == weight_type
                })
                .map(|(mmm, p, _, _)| (mmm.clone(), p))
                .min_by_key(|(mmm, _)| {
                    mmm.quality().cost() as isize * 1000 - (mmm.mr() * mmm.nr()) as isize
                })
                .context("No matmul found")
        } else {
            let mmm = tract_linalg::ops()
                .mmm(acc, Some(m), Some(k), n.to_usize().ok())
                .context("No matmul found")?;
            let packing = mmm
                .packings()
                .iter()
                .position(|p| {
                    p.0.precursor() == w_dt.unquantized().into()
                        && p.1.precursor() == x_dt.unquantized().into()
                })
                .context("No packing found")?;
            Ok((mmm, packing))
        }
    }

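    /// Build the final `OptMatMul` node: pack the kernel, map the output
    /// batch and group axes back onto the operands, fuse the bias in when it
    /// is not trivially zero, then store through a view over the m and n
    /// output axes.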
    #[allow(clippy::too_many_arguments)]
    fn wire_mm_weights_bias(
        &self,
        model: &mut TypedModel,
        name: &str,
        input: OutletId,
        g_o_ihw: OutletId,
        bias: OutletId,
        mmm: Box<dyn MatMatMul>,
        packing: usize,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        k: usize,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
        let a_pack = &mmm.packings()[packing].0;
        let packed_ker = self
            .wire_pack_g_o_ihw(model, name, &**a_pack, g_o_ihw)
            .context("in wire_pack_g_o_ihw")?;
        let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());

        c_to_a_axis_mapping.push((c_m_axis - 1, 0)); // group
        c_to_b_axis_mapping.push((0, 0)); // batch
        c_to_b_axis_mapping.push((c_m_axis - 1, 1)); // group

        let geo = AddMatMulGeometry {
            k: k.to_dim(),
            c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
            c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
        };
        let mut ops: Vec<ProtoFusedSpec> =
            vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
        let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
        let bias_fact = model.outlet_fact(bias)?;
        if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
            let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
            wires.push(bias);
            ops.push(fused);
        }
        ops.push(ProtoFusedSpec::Store(vec![unsafe {
            mmm.c_view(Some(c_m_axis), Some(c_n_axis))
        }]));
        model.wire_node(
            format!("{name}.matmatmul"),
            OptMatMul::new(
                vec![mmm],
                ModePicker::Single,
                c_datum_type.fact(mmm_output_shape),
                Some(c_m_axis),
                Some(c_n_axis),
                ops,
                packing == 0 && self.group == 1,
            )?,
            &wires,
        )
    }

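    /// Depthwise special case (group == input channels == output channels):
    /// uses the dedicated `DepthWise` kernel instead of im2col plus matmul.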
    pub fn wire_as_depth_wise(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<OutletId> {
        let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let x_shape = x_fact.shape.as_concrete().unwrap();
        let ConcretePoolGeometry { input_shape, patch, output_shape } =
            self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
        bias = wire_reshape_bias_for_bin(
            model,
            name,
            bias,
            x_fact.rank(),
            c_axis,
            self.output_channels(),
        )?[0];
        let op = DepthWise::new(patch, input_shape, output_shape);
        Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
    }

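    /// If an axis has stride > 1 but a unit kernel dimension (or a dilation
    /// that is a multiple of the stride), the stride can be factored out as
    /// an explicit `Downsample` on the input, leaving a cheaper unit-stride
    /// convolution behind.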
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let spatial_rank = self.pool_spec.rank();
        if let Some(axis) = (0..spatial_rank).find(|&ax| {
            self.pool_spec.stride(ax) > 1
                && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                && (self.pool_spec.kernel_shape[ax] == 1
                    || self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0)
        }) {
            let input_fact = model.outlet_fact(node.inputs[0])?;
            let downsample_factor = self.pool_spec.stride(axis);
            let mut new_op = self.clone();
            if new_op.pool_spec.dilation(axis) > 1 {
                new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
            }
            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
            let mut patch = TypedModelPatch::default();
            let mut taps = patch.taps(model, &node.inputs)?;
            let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
            taps[0] = patch.wire_node(
                format!("{}.downsample.{}", node.name, axis),
                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
                &[taps[0]],
            )?[0];
            let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

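    /// A 1x1, unit-stride, unit-dilation, unpadded and ungrouped convolution
    /// is a plain matrix product over the channel axes: rewrite it as an
    /// `EinSum`, with the bias either handled by the einsum (quantized case)
    /// or wired as a separate addition.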
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

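    /// Absorb a preceding constant zero `Pad` node (padding spatial axes
    /// only) into this convolution's explicit padding.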
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if matches!(self.pool_spec.padding, ExplicitOnnxPool(_, _, _) | SameLower | SameUpper) {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        if !value.is_zero()?
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let Explicit(bef, aft) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = Explicit(before, after);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.taps(model, &node.inputs)?;
        wire[0] = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &wire)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

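    /// Fold a following per-channel binary operation (add, sub, mul or div by
    /// a tensor shaped 1,..,C,..,1) into the convolution's kernel and/or
    /// bias, so the arithmetic is done once on the weights rather than on
    /// every output element.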
    fn declutter_channel_arithmetic_succ(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() || self.group != 1 {
            return Ok(None);
        }
        let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) };
        let succ = model.node(succ_outlet.node);
        let Some(bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
        let other_input = succ.inputs[1 - succ_outlet.slot];
        let axes_mapping = model.node_axes_mapping(succ.id)?;
        let input_shape =
            self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        let conv_c_axis = input_shape.c_axis();
        if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
            [1 - succ_outlet.slot]
            .len()
            != 1
        {
            return Ok(None);
        };
        let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
        other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
        if *other_expected_shape != *model.outlet_fact(other_input)?.shape {
            return Ok(None);
        }

        let mut patch = TypedModelPatch::default();
        let [input, mut kernel, mut bias] = *patch.taps(model, &node.inputs)? else {
            panic!("Expect three inputs");
        };
        let name = &node.name;
        let succ_name = &succ.name;

        let operand = patch.tap_model(model, other_input)?;

        let renamed_bias = format!("{name}.{succ_name}.bias");
        let renamed_kernel = format!("{name}.{succ_name}.kernel");
        bias = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape"),
            bias,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape_operand"),
            operand,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
        let kernel_fact = patch.outlet_fact(kernel)?;
        let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
        operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
            self.output_channels().to_dim();
        let operand_for_kernel = patch.wire_node(
            format!("{renamed_kernel}.reshape_operand"),
            AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
            &[operand],
        )?[0];

        if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
        } else if bin.0.is::<Sub>() {
            bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
        } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
        } else if bin.0.is::<Div>() {
            bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
        } else if bin.0.is::<Add>() {
            bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
        } else if bin.0.is::<Mul>() {
            bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
        } else {
            return Ok(None);
        };
        let wire = patch.wire_node(&node.name, self.clone(), &[input, kernel, bias])?[0];
        patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
        Ok(Some(patch))
    }
}

impl Op for Conv {
    fn name(&self) -> StaticName {
        "Conv".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

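// Evaluation builds a one-off model around the same im2col lowering used by
// codegen and runs it: a slow path, but it keeps eval and codegen consistent
// by construction.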
impl EvalOp for Conv {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();
        let wire: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
            .collect::<TractResult<_>>()?;
        let wire = unsafe {
            if self.q_params.is_some() {
                self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
            }
        };
        model.set_output_outlets(&wire)?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for Conv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
        let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
        ensure!(inputs[1].datum_type.is_number() || self.kernel_fmt == KernelFormat::OIHW);
        if inputs.len() != 3 + q_inputs {
            bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
        }
        if self.q_params.is_some() {
            ensure!(inputs[2].datum_type == i32::datum_type());
            ensure!(inputs[3].datum_type == i32::datum_type());
            ensure!(inputs[4].datum_type.is_float());
            ensure!(inputs[5].datum_type == i32::datum_type());
            ensure!(inputs[6].datum_type.is_float());
            ensure!(inputs[7].datum_type == i32::datum_type());
            ensure!(inputs[8].datum_type.is_float());
        }
        let weight_shape = block_quant_aware_weight_shape(inputs[1])?;
        ensure!(self.pool_spec.rank() + 2 == weight_shape.len());
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
            anyhow::ensure!(bef.len() == self.pool_spec.rank());
            anyhow::ensure!(after.len() == self.pool_spec.rank());
        }
        ensure!(
            inputs[2].rank() == 0
                || (inputs[2].rank() == 1
                    && inputs[2].shape.volume() == self.output_channels().to_dim()),
            "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
            self.output_channels(),
            inputs[2]
        );
        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some(dt) = self.q_params {
            fact.datum_type = dt;
        } else {
            ensure!(
                inputs[1].datum_type.is_opaque() || inputs[0].datum_type == inputs[1].datum_type,
                "Convolution input, weights and bias must have the same type, got {inputs:?}",
            )
        }
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let padding = self.pool_spec.computed_padding(shape.hw_dims());
        for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim == 1
                && self.pool_spec.dilation(ix) == 1
                && self.pool_spec.stride(ix) == 1
                && padding[ix].pad_before.is_zero()
                && padding[ix].pad_after.is_zero()
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking(repr, (InOut::Out(0), ix + h_axis))?;
            }
        }
        if self.q_params.is_some() {
            for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
                if qp.rank() == 1 {
                    axes = match qp_ix {
                        3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
                        5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        _ => unreachable!(),
                    };
                }
            }
        }
        Ok(axes)
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_stride_slice_to_downsample);
        pass!(declutter_as_einsum);
        pass!(declutter_channel_arithmetic_succ);
        pass!(declutter_precursor_padding);
        Ok(None)
    }

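    /// Cost model: the FMA count is N * Ci * Co * output positions * kernel
    /// surface / group.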
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!((
            Cost::FMA(inputs[0].datum_type),
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if io == InOut::In(1) {
            return Ok(None);
        }
        if io == InOut::In(2) {
            if let &AxisOp::Rm(_) = change {
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(self.clone())),
                    wire_changes: tvec!(),
                }));
            }
        }
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(op)),
                    wire_changes: tvec!(
                        (InOut::In(0), change.clone()),
                        (InOut::Out(0), change.clone())
                    ),
                }));
            }
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        if model.node_input_facts(node.id)?[1].datum_type.is_opaque() {
            return Ok(None);
        }
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = self.kernel_fmt.h_axis();
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                let geo_axis = a - h_axis;
                (Rm(geo_axis), Rm(kh_axis + geo_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
        let new_op = Conv { pool_spec, ..self.clone() };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::In(1), kernel_adjusted),
                (InOut::Out(0), change.clone())
            ),
        }))
    }

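    /// Lowering strategy: quantized convolutions use the quantized im2col
    /// path; mono-batch float convolutions with large kernels use the lazy
    /// im2col; fully-grouped convolutions become `DepthWise`; everything else
    /// takes the regular im2col and matmul pair.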
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        unsafe {
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
                    .context("in wire_as_quant_im2col")?;
                patch.shunt_outside(model, node.id.into(), wire[0])?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
                    .context("wire_as_lazy_im2col")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_depth_wise(&mut patch, &node.name, &inputs)
                    .context("wire_as_depth_wise")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
                    .context("in wire_as_im2col_pair")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}

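/// Heuristic for the lazy im2col path: mono-batch, ungrouped convolutions
/// with more than 5 kernel positions, where materializing the full im2col
/// buffer is presumably not worth the memory traffic.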
fn should_use_lazy(input_shape: &DataShape, pool_spec: &PoolSpec, group: usize) -> bool {
    input_shape.n().unwrap_or(&1) == &1
        && group == 1
        && pool_spec.kernel_shape.iter().product::<usize>() > 5
}

#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    #[test]
    fn onnx_basic_convinteger() {
        let op = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            rctensor0(0u32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let wire = model.wire_node(
            "conv",
            Conv {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: Explicit(tvec![0], tvec![0]),
                    input_channels: 1,
                    output_channels: 1,
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                group: 1,
                q_params: None,
            },
            &[wire[0], kernel, bias],
        )?;
        model.set_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 4); // source, kernel, bias, conv: Pad is gone
        let cv = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0])); // absorbed the Pad
        Ok(())
    }
}