use tract_data::itertools::izip;
use tract_num_traits::Zero;

use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::PaddingSpec::*;
use crate::ops::einsum::EinSum;
use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
use crate::ops::matmul::pack::OptMatMulPack;
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::matmul::ModePicker;
use crate::ops::nn::Reduce;

use super::depth_wise::DepthWise;
use super::im2col::Im2Col;
use crate::ops::cnn::conv::KernelFormat;
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::MatMatMul;

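/// Convolution over N spatial dimensions, with optional grouping and
/// quantization. `pool_spec` carries the geometry (kernel shape, strides,
/// dilations, padding, input/output channels), `kernel_fmt` the weight
/// layout, `group` the number of feature groups, and `q_params` the output
/// datum type when the convolution is quantized (the op then takes six
/// extra zero-point and scale inputs).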
#[derive(Debug, Clone, new, Hash)]
pub struct Conv {
    pub pool_spec: PoolSpec,
    pub kernel_fmt: KernelFormat,
    pub group: usize,
    pub q_params: Option<DatumType>,
}

impl Conv {
    pub fn input_channels(&self) -> usize {
        self.pool_spec.input_channels
    }

    pub fn output_channels(&self) -> usize {
        self.pool_spec.output_channels
    }

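    /// Wire the ops that bring the kernel to the canonical group-o-ihw
    /// ([group, output, input*h*w]) layout, whatever `kernel_fmt` it starts in.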
    pub fn wire_kernel_as_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut kernel: OutletId,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(kernel)?;
        for (ix, op) in self
            .kernel_fmt
            .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
            .into_iter()
            .enumerate()
        {
            kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
        }
        Ok(tvec!(kernel))
    }

    fn wire_pack_g_o_ihw(
        &self,
        model: &mut TypedModel,
        name: &str,
        format: PackedFormat,
        kernel: OutletId,
    ) -> TractResult<OutletId> {
        Ok(model.wire_node(
            format!("{name}.prep_kernel.pack"),
            OptMatMulPack {
                packers: vec![format],
                k_axis: 2,
                mn_axis: 1,
                mode_picker: ModePicker::Single,
            },
            &[kernel],
        )?[0])
    }

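    /// Express the bias as a fused post-op of the matmul: a scalar add when
    /// the bias has a single element, otherwise a per-row add after splitting
    /// the bias along the group axis.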
    fn wire_bias_as_non_linear(
        &self,
        model: &mut TypedModel,
        name: &str,
        bias: OutletId,
        c_group_axis: usize,
    ) -> TractResult<(ProtoFusedSpec, OutletId)> {
        use tract_linalg::BinOp::Add;
        let fact = model.outlet_fact(bias)?;
        if fact.shape.volume().is_one() {
            Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
        } else {
            let bias = AxisOp::wire_split_axis(
                model,
                format!("{name}.reformat_bias"),
                bias,
                0,
                self.group,
            )?[0];
            let pfs =
                ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
            Ok((pfs, bias))
        }
    }

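    /// Quantized im2col path: wires the convolution as im2col plus packed
    /// matmul in i32, compensates the input and kernel zero-points using the
    /// kernel and input sums, then requantizes to the output type.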
119 pub unsafe fn wire_as_quant_im2col(
120 &self,
121 model: &mut TypedModel,
122 name: &str,
123 wires: &[OutletId],
124 ) -> TractResult<TVec<OutletId>> {
125 ensure!(self.q_params.is_some());
126 use crate::ops::matmul::quant as qmm;
127
128 let c_dt = self.q_params.unwrap();
129 let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
130 else {
131 bail!("Wrong number of inputs")
132 };
133 wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
134 wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;
135
136 let b_fact = model.outlet_fact(x)?.clone();
137
138 let (_, _, k, n, mmm) = self.compute_geo(&b_fact)?;
139 let packing = 1; let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
141
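        // Requantization runs before the final geometric reshape, so a
        // non-scalar kernel scale presumably needs a trailing broadcast axis
        // unless the channel axis already comes last.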
        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[x, x0],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }

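    /// Merge the group axis back into the channel axis of the matmul output,
    /// or simply drop it when `group == 1`.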
    pub fn wire_remove_group<D: DimLike>(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        mmm_output_shape: &[D],
        c_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let m = &mmm_output_shape[c_axis];
        let op = if self.group == 1 {
            AxisOp::Rm(c_axis - 1)
        } else {
            AxisOp::Reshape(
                c_axis - 1,
                tvec!(self.group.to_dim(), m.to_dim()),
                tvec!(m.to_dim() * self.group),
            )
        };
        model.wire_node(format!("{name}.reshape_group"), op, wire)
    }

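    /// Float im2col path: the input is rearranged by an Im2Col node (fed a
    /// zero scalar as its padding value), the kernel is packed once, and the
    /// convolution proper becomes a single OptMatMul.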
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[x, _kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let b_dt = x_fact.datum_type;
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);

        let (_, _, k, _, mmm) = self.compute_geo(&x_fact)?;
        let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;

        let padding = model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(b_dt)?)?;

        let mut wire: TVec<_> = wire.into();
        wire[0] = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &x_fact.shape, mmm.clone())?,
            &[wire[0], padding],
        )?[0];

        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;

        let wire = self
            .wire_mm_weights_bias(
                model,
                name,
                wire[0],
                g_o_ihw[0],
                bias,
                mmm,
                0,
                c_dt,
                mmm_output_shape.clone().into(),
                k.to_usize().unwrap(),
                c_axis,
                h_axis,
            )
            .context("in wire_opt_matmul")?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
    }

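    /// Shape of the OptMatMul output: the spatial dims are collapsed into a
    /// single axis and a group axis is inserted before the (per-group)
    /// channel axis. Returns the shape along with the channel and spatial
    /// axis positions.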
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
        mmm_output_shape.insert(c_axis, self.group.into());
        if h_axis > c_axis {
            h_axis += 1;
        }
        c_axis += 1;
        Ok((mmm_output_shape, c_axis, h_axis))
    }

    fn wire_rm_n_if_needed(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if self.pool_spec.data_format.has_n() {
            Ok(wire.into())
        } else {
            model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
        }
    }

    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<TVec<OutletId>> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        model
            .wire_node(
                name,
                AxisOp::Reshape(
                    output_shape.h_axis(),
                    tvec!(geo_collapsed_out.to_dim()),
                    output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
                ),
                wire,
            )
            .context("in wire_geo_reshape")
    }

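    /// Lazy im2col path: padding is materialized up front by a Pad node (the
    /// pool spec is then switched to Valid), and the input patches are
    /// gathered on the fly by LazyIm2Col through precomputed byte offsets
    /// instead of a materialized im2col buffer.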
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
        let mut x_fact = model.outlet_fact(x)?.clone();
        let (geo, m, k, n, mmm) = self.compute_geo(&x_fact)?;
        let packing = 0;
        debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
        let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); x_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
            x_fact = model.outlet_fact(x)?.clone();
            let concrete_shape = x_fact.shape.as_concrete().unwrap();
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&x_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = x_fact.datum_type.size_of() as isize;
        let n_byte_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_byte_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
        let packer = mmm.packings()[packing]
            .1
            .downcast_ref::<PackedFormat>()
            .with_context(|| {
                format_err!(
                    "Lazy Im2Col expects a regular packed format, got {:?}",
                    mmm.packings()[packing].1
                )
            })?
            .clone();
        let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
        let x = model.wire_node(
            format!("{name}.lazyIm2col"),
            LazyIm2Col { params: Arc::new(params) },
            &[x],
        )?[0];

        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            x,
            kernel,
            bias,
            mmm,
            packing,
            c_dt,
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
    }

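    /// Compute the pooling geometry and the m/k/n dimensions of the
    /// equivalent matrix multiplication, then pick a matching kernel from
    /// tract-linalg: m is output channels per group, k is input channels per
    /// group times the kernel surface, n is the number of output spatial
    /// positions.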
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
        let b_dt = input_fact.datum_type;
        let acc = if b_dt.is_float() { b_dt } else { i32::datum_type() };

        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;

        trace!("output channels: {:?}", self.output_channels());
        let m = self.output_channels() / self.group;
        let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
            / self.group;
        let n: TDim =
            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();

        let mmm = tract_linalg::ops()
            .mmm(acc, Some(m), Some(k), n.to_usize().ok())
            .with_context(|| format!("No multiplier for {acc:?}, {m}x{k}x{n}"))?;

        Ok((geo, m, k, n, mmm))
    }

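    /// Wire the OptMatMul node itself: pack the kernel, map the output group
    /// and N axes back onto the two operands, fuse the bias in (unless it is
    /// a constant zero) and store through the matmul kernel's C view.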
    #[allow(clippy::too_many_arguments)]
    fn wire_mm_weights_bias(
        &self,
        model: &mut TypedModel,
        name: &str,
        input: OutletId,
        g_o_ihw: OutletId,
        bias: OutletId,
        mmm: Box<dyn MatMatMul>,
        packing: usize,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        k: usize,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
        let a_pack = mmm.packings()[packing]
            .0
            .downcast_ref::<PackedFormat>()
            .context("Conv expects weights in regular packed format")?
            .clone();
        let packed_ker = self
            .wire_pack_g_o_ihw(model, name, a_pack, g_o_ihw)
            .context("in kernel_as_packed_as")?;
        let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());

        c_to_a_axis_mapping.push((c_m_axis - 1, 0));
        c_to_b_axis_mapping.push((0, 0));
        c_to_b_axis_mapping.push((c_m_axis - 1, 1));
        let geo = AddMatMulGeometry {
            k: k.to_dim(),
            c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
            c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
        };
        let mut ops: Vec<ProtoFusedSpec> =
            vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
        let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
        let bias_fact = model.outlet_fact(bias)?;
        if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
            let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
            wires.push(bias);
            ops.push(fused);
        }
        ops.push(ProtoFusedSpec::Store(vec![unsafe { mmm.c_view(c_m_axis, c_n_axis) }]));
        model.wire_node(
            format!("{name}.matmatmul"),
            OptMatMul::new(
                vec![mmm],
                ModePicker::Single,
                c_datum_type.fact(mmm_output_shape),
                c_m_axis,
                c_n_axis,
                ops,
                packing == 0 && self.group == 1,
            )?,
            &wires,
        )
    }

    pub fn wire_as_depth_wise(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: &[OutletId],
    ) -> TractResult<OutletId> {
        let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
        let x_fact = model.outlet_fact(x)?.clone();
        let x_shape = x_fact.shape.as_concrete().unwrap();
        let ConcretePoolGeometry { input_shape, patch, output_shape } =
            self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
        let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
        bias = wire_reshape_bias_for_bin(
            model,
            name,
            bias,
            x_fact.rank(),
            c_axis,
            self.output_channels(),
        )?[0];
        let op = DepthWise::new(patch, input_shape, output_shape);
        Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
    }

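    /// Declutter: when a spatial axis is strided, validly padded, and the
    /// kernel cannot see the skipped positions (kernel dim of 1, or dilation
    /// a multiple of the stride), pull the stride out as an explicit
    /// Downsample in front of a stride-1 convolution.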
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let spatial_rank = self.pool_spec.rank();
        if let Some(axis) = (0..spatial_rank).find(|&ax| {
            self.pool_spec.stride(ax) > 1
                && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                && (self.pool_spec.kernel_shape[ax] == 1
                    || self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0)
        }) {
            let input_fact = model.outlet_fact(node.inputs[0])?;
            let downsample_factor = self.pool_spec.stride(axis);
            let mut new_op = self.clone();
            if new_op.pool_spec.dilation(axis) > 1 {
                new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
            }
            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
            let mut patch = TypedModelPatch::default();
            let mut taps = patch.taps(model, &node.inputs)?;
            let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
            taps[0] = patch.wire_node(
                format!("{}.downsample.{}", node.name, axis),
                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
                &[taps[0]],
            )?[0];
            let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

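    /// Declutter: a 1x1, stride-1, dilation-1, unpadded, ungrouped
    /// convolution is just a matrix product over the channel axes, so rewrite
    /// it as an EinSum (keeping the bias as a fused input in the quantized
    /// case, as a broadcast add otherwise).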
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

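    /// Declutter: absorb a preceding constant-zero Pad node (touching spatial
    /// axes only) into the convolution's own explicit padding.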
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if matches!(self.pool_spec.padding, ExplicitOnnxPool(_, _, _) | SameLower | SameUpper) {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        if !value.is_zero()?
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let Explicit(bef, aft) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = Explicit(before, after);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.taps(model, &node.inputs)?;
        wire[0] = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &wire)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

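    /// Declutter: fold a following elementwise Add/Sub/Mul/Div along the
    /// output channel axis into the convolution's bias (and, for Mul/Div,
    /// into the kernel as well).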
    fn declutter_channel_arithmetic_succ(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() || self.group != 1 {
            return Ok(None);
        }
        let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) };
        let succ = model.node(succ_outlet.node);
        let Some(bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
        let other_input = succ.inputs[1 - succ_outlet.slot];
        let axes_mapping = model.node_axes_mapping(succ.id)?;
        let input_shape =
            self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        let conv_c_axis = input_shape.c_axis();
        if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
            [1 - succ_outlet.slot]
            .len()
            != 1
        {
            return Ok(None);
        };
        let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
        other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
        if *other_expected_shape != *model.outlet_fact(other_input)?.shape {
            return Ok(None);
        }

        let mut patch = TypedModelPatch::default();
        let [input, mut kernel, mut bias] = &*patch.taps(model, &node.inputs)? else {
            panic!("Expected three inputs");
        };
        let name = &node.name;
        let succ_name = &succ.name;

        let operand = patch.tap_model(model, other_input)?;

        let renamed_bias = format!("{name}.{succ_name}.bias");
        let renamed_kernel = format!("{name}.{succ_name}.kernel");
        bias = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape"),
            bias,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand = wire_reshape_bias_for_bin(
            &mut patch,
            format!("{renamed_bias}.reshape_operand"),
            operand,
            1,
            0,
            self.output_channels(),
        )?[0];

        let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
        let kernel_fact = patch.outlet_fact(kernel)?;
        let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
        operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
            self.output_channels().to_dim();
        let operand_for_kernel = patch.wire_node(
            format!("{renamed_kernel}.reshape_operand"),
            AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
            &[operand],
        )?[0];

        if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
        } else if bin.0.is::<Sub>() {
            bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
        } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
            bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
        } else if bin.0.is::<Div>() {
            bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
            kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
        } else if bin.0.is::<Add>() {
            bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
        } else if bin.0.is::<Mul>() {
            bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
            kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
        } else {
            return Ok(None);
        };
        let wire = patch.wire_node(&node.name, self.clone(), &[*input, kernel, bias])?[0];
        patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
        Ok(Some(patch))
    }
}

impl Op for Conv {
    fn name(&self) -> Cow<str> {
        "Conv".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl EvalOp for Conv {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();
        let wire: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
            .collect::<TractResult<_>>()?;
        let wire = unsafe {
            if self.q_params.is_some() {
                self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
            }
        };
        model.set_output_outlets(&wire)?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for Conv {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
        let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
        if inputs.len() != 3 + q_inputs {
            bail!("Wrong number of inputs: expected {}, got {}", 3 + q_inputs, inputs.len());
        }
        if self.q_params.is_some() {
            ensure!(inputs[2].datum_type == i32::datum_type());
            ensure!(inputs[3].datum_type == i32::datum_type());
            ensure!(inputs[4].datum_type.is_float());
            ensure!(inputs[5].datum_type == i32::datum_type());
            ensure!(inputs[6].datum_type.is_float());
            ensure!(inputs[7].datum_type == i32::datum_type());
            ensure!(inputs[8].datum_type.is_float());
        }
        ensure!(self.pool_spec.rank() + 2 == inputs[1].rank());
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
            anyhow::ensure!(bef.len() == self.pool_spec.rank());
            anyhow::ensure!(after.len() == self.pool_spec.rank());
        }
        ensure!(
            inputs[2].rank() == 0
                || (inputs[2].rank() == 1
                    && inputs[2].shape.volume() == self.output_channels().to_dim()),
            "Bias should be a scalar or a vector with one value per output channel. Output channels: {}, bias: {:?}",
            self.output_channels(),
            inputs[2]
        );
        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some(dt) = self.q_params {
            fact.datum_type = dt;
        } else {
            ensure!(
                inputs[0].datum_type == inputs[1].datum_type,
                "Convolution input and weights must have the same type, got {inputs:?}",
            )
        }
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(&fact.shape)?;
        let mut axes = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), shape.c_axis()), 'I')?
            .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            axes = axes
                .renaming((InOut::In(0), n_axis), 'N')?
                .linking('N', (InOut::Out(0), n_axis))?;
        }
        let h_axis = shape.h_axis();
        let geo = "HWXYZ".chars().chain('a'..);
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let padding = self.pool_spec.computed_padding(shape.hw_dims());
        for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
            if dim == 1
                && self.pool_spec.dilation(ix) == 1
                && self.pool_spec.stride(ix) == 1
                && padding[ix].pad_before.is_zero()
                && padding[ix].pad_after.is_zero()
            {
                axes = axes
                    .renaming((InOut::In(0), ix + h_axis), repr)?
                    .linking(repr, (InOut::Out(0), ix + h_axis))?;
            }
        }
        if self.q_params.is_some() {
            for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
                if qp.rank() == 1 {
                    axes = match qp_ix {
                        3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
                        5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
                        _ => unreachable!(),
                    };
                }
            }
        }
        Ok(axes)
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_stride_slice_to_downsample);
        pass!(declutter_as_einsum);
        pass!(declutter_channel_arithmetic_succ);
        pass!(declutter_precursor_padding);
        Ok(None)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape = &self.pool_spec.kernel_shape;
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!((
            Cost::FMA(inputs[0].datum_type),
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if io == InOut::In(1) {
            return Ok(None);
        }
        if io == InOut::In(2) {
            if let &AxisOp::Rm(_) = change {
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(self.clone())),
                    wire_changes: tvec!(),
                }));
            }
        }
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence {
                    substitute_op: Some(Box::new(op)),
                    wire_changes: tvec!(
                        (InOut::In(0), change.clone()),
                        (InOut::Out(0), change.clone())
                    ),
                }));
            }
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = self.kernel_fmt.h_axis();
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                let geo_axis = a - h_axis;
                (Rm(geo_axis), Rm(kh_axis + geo_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
        let new_op = Conv { pool_spec, ..self.clone() };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::In(1), kernel_adjusted),
                (InOut::Out(0), change.clone())
            ),
        }))
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        unsafe {
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
                    .context("in wire_as_quant_im2col")?;
                patch.shunt_outside(model, node.id.into(), wire[0])?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
                    .context("wire_as_lazy_im2col")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_depth_wise(&mut patch, &node.name, &inputs)
                    .context("wire_as_depth_wise")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &node.inputs)?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
                    .context("in wire_as_im2col_pair")?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}

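/// Heuristic for the lazy im2col path: presumably worthwhile only for a
/// single-batch, ungrouped convolution with a kernel surface larger than 5.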
fn should_use_lazy(input_shape: &DataShape, pool_spec: &PoolSpec, group: usize) -> bool {
    input_shape.n().unwrap_or(&1) == &1
        && group == 1
        && pool_spec.kernel_shape.iter().product::<usize>() > 5
}

#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    #[test]
    fn onnx_basic_convinteger() {
        let op = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            rctensor0(0u32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let wire = model.wire_node(
            "conv",
            Conv {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: Explicit(tvec![0], tvec![0]),
                    input_channels: 1,
                    output_channels: 1,
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                group: 1,
                q_params: None,
            },
            &[wire[0], kernel, bias],
        )?;
        model.set_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 4);
        let cv = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0]));
        Ok(())
    }
}