1use tract_data::itertools::izip;
2use tract_linalg::WeightType;
3use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFormat};
4use tract_num_traits::Zero;
5
6use crate::internal::*;
7use crate::model::*;
8use crate::ops;
9use crate::ops::array::Pad;
10use crate::ops::array::PadMode;
11use crate::ops::binary::TypedBinOp;
12use crate::ops::cast::cast;
13use crate::ops::cnn::PaddingSpec::*;
14use crate::ops::cnn::conv::block_quant::{BlockQuantIntoShape, SplitGroupBlockQuant};
15use crate::ops::cnn::conv::lazy_im2col::LazyIm2Col;
16use crate::ops::cnn::conv::lazy_im2col::LazyIm2colParams;
17use crate::ops::cnn::wire_reshape_bias_for_bin;
18use crate::ops::einsum::EinSum;
19use crate::ops::math::{Add, Div, Mul, Sub};
20use crate::ops::math::{add, div, mul, sub};
21use crate::ops::matmul::ModePicker;
22use crate::ops::matmul::optimized::AddMatMulGeometry;
23use crate::ops::matmul::optimized::MapOutputAxisToInput;
24use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
25use crate::ops::matmul::quant::wire_ensure_q8_flavour;
26use crate::ops::nn::Reduce;
27
28use super::depth_wise::DepthWise;
29use super::im2col::Im2Col;
30use crate::ops::cnn::conv::KernelFormat;
31use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
32use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
33use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
34
35use tract_linalg::mmm::{MMMInputFormat, MatMatMul};
36use tract_linalg::pack::{PackedFormat, PackedI8K4};
37
/// Typed convolution operator.
///
/// Spatial geometry (kernel shape, strides, dilations, padding, data format) and the channel
/// counts live in `pool_spec`; the remaining fields describe the kernel layout, grouping and
/// optional quantization.
#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
pub struct Conv {
    /// Spatial geometry and input/output channel counts.
    pub pool_spec: PoolSpec,
    /// Layout of the kernel input (e.g. `KernelFormat::OIHW`).
    pub kernel_fmt: KernelFormat,
    /// Number of channel groups; channel counts are divided by it to get per-group sizes.
    pub group: usize,
    /// `Some(dt)` selects the quantized path, `dt` being the requantized output type;
    /// `None` selects the float path.
    pub q_params: Option<DatumType>,
}
49
50impl Conv {
51 pub fn input_channels(&self) -> usize {
52 self.pool_spec.input_channels
53 }
54
55 pub fn output_channels(&self) -> usize {
56 self.pool_spec.output_channels
57 }
58
59 pub fn wire_kernel_as_g_o_ihw(
60 &self,
61 model: &mut TypedModel,
62 name: &str,
63 mut kernel: OutletId,
64 ) -> TractResult<TVec<OutletId>> {
65 let fact = model.outlet_fact(kernel)?;
66 if fact.is_exotic() {
67 ensure!(self.kernel_fmt == KernelFormat::OIHW && fact.rank() >= 2);
68 kernel = model.wire_node(
69 format!("{name}.prep_kernel.g"),
70 SplitGroupBlockQuant { group: self.group },
71 &[kernel],
72 )?[0];
73 kernel = model.wire_node(
74 format!("{name}.prep_kernel.ihw"),
75 BlockQuantIntoShape {
76 shape: tvec!(
77 self.output_channels() / self.group,
78 self.input_channels() / self.group
79 * self.pool_spec.kernel_shape.iter().product::<usize>(),
80 ),
81 },
82 &[kernel],
83 )?[0];
84 Ok(tvec!(kernel))
85 } else {
86 for (ix, op) in self
87 .kernel_fmt
88 .kernel_as_group_o_ihw_ops(&fact.shape, self.group)
89 .into_iter()
90 .enumerate()
91 {
92 kernel = model.wire_node(format!("{name}.prep_kernel.{ix}"), op, &[kernel])?[0];
93 }
94 Ok(tvec!(kernel))
95 }
96 }
97
98 fn wire_pack_g_o_ihw(
99 &self,
100 model: &mut TypedModel,
101 name: &str,
102 format: &dyn MMMInputFormat,
103 kernel: OutletId,
104 ) -> TractResult<OutletId> {
105 let fact = model.outlet_fact(kernel)?;
106 let wire = if fact.is_exotic() {
107 let fact = model
108 .outlet_fact(kernel)?
109 .exotic_fact
110 .as_ref()
111 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
112 .context("Only manage BlockQuant")?;
113 model.wire_node(
114 format!("{name}.prep_kernel.pack"),
115 OptSimpleMatMulPack {
116 packed_format: format
117 .downcast_ref::<PackedBlockQuantFormat>()
118 .context("Expect a block quant format")?
119 .clone(),
120 k: fact.k(),
121 m: fact.m(),
122 },
123 &[kernel],
124 )?
125 } else {
126 model.wire_node(
128 format!("{name}.prep_kernel.pack"),
129 OptMatMulPack {
130 packers: vec![dyn_clone::clone_box(format)],
131 k_axis: 2,
132 mn_axis: 1,
133 mode_picker: ModePicker::Single,
134 },
135 &[kernel],
136 )?
137 };
138 Ok(wire[0])
139 }
140
141 fn wire_bias_as_non_linear(
143 &self,
144 model: &mut TypedModel,
145 name: &str,
146 bias: OutletId,
147 c_group_axis: usize,
148 ) -> TractResult<(ProtoFusedSpec, OutletId)> {
149 use tract_linalg::BinOp::Add;
150 let fact = model.outlet_fact(bias)?;
151 if fact.shape.volume().is_one() {
152 Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
153 } else {
154 let bias = AxisOp::wire_split_axis(
155 model,
156 format!("{name}.reformat_bias"),
157 bias,
158 0,
159 self.group,
160 )?[0];
161 let pfs =
162 ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
163 Ok((pfs, bias))
164 }
165 }
166
    /// Wire a quantized convolution as `Im2Col` + optimized matmul. Zero-point compensation
    /// follows, then requantization to the datum type held in `self.q_params`.
    ///
    /// Expects exactly nine inputs, in this order:
    /// `x, kernel, bias, x0, x_scale, k0, k_scale, y0, y_scale`.
    /// Any other count fails with "Wrong number of inputs".
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this block. It is `unsafe` like the
    /// other `wire_as_*` entry points; confirm the actual requirements.
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        wires: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        ensure!(self.q_params.is_some());
        use crate::ops::matmul::quant as qmm;

        // Output datum type after requantization.
        let c_dt = self.q_params.unwrap();
        let &[mut x, mut kernel, bias, mut x0, x_scale, mut k0, mut k_scale, y0, y_scale] = wires
        else {
            bail!("Wrong number of inputs")
        };
        // Bring kernel and input (with their zero points) to the i8 flavour.
        wire_ensure_q8_flavour(model, name, &mut kernel, "k", &mut k0, i8::datum_type())?;
        wire_ensure_q8_flavour(model, name, &mut x, "x", &mut x0, i8::datum_type())?;

        let a_fact = model.outlet_fact(kernel)?.clone();
        let b_fact = model.outlet_fact(x)?.clone();

        let (_geo, m, k, n) = self.compute_geo(&b_fact)?;
        let (mmm, packing) = self.choose_impl(&b_fact, &a_fact, m, k, &n)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        // A non-scalar kernel scale gets an extra axis at position 1 when channels are not
        // the last axis of the output.
        if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
            if !output_shape.fmt.c_is_last() {
                k_scale = model.wire_node(
                    format!("{name}.a_scale_axis_fix"),
                    AxisOp::Add(1),
                    &[k_scale],
                )?[0];
            }
        }

        let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(
                self.pool_spec.clone(),
                self.group,
                k,
                &b_fact.shape,
                mmm.clone(),
                packing,
            )?,
            &[x, x0],
        )?[0];

        // Kernel sums over the ihw axis (as i32), fed to zero-point compensation below.
        let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
        let g_o_ihw_as_i32 =
            model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
        let sum_ker_g_c_k = model.wire_node(
            format!("{name}.sum_ker_g_c_k"),
            Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
            &g_o_ihw_as_i32,
        )?;
        let sum_ker_a_g_c =
            model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
        // Prepend a batch axis, then insert a spatial axis whose position depends on
        // whether channels are last in the data format.
        let sum_ker_n_g_c = model.wire_node(
            format!("{name}.sum_ker_n_g_c.axis_0"),
            AxisOp::Add(0),
            &sum_ker_a_g_c,
        )?;
        let hw_position = if self.pool_spec.data_format.c_is_last() { 1 } else { 3 };
        let sum_ker = model.wire_node(
            format!("{name}.sum_ker_n_g_c"),
            AxisOp::Add(hw_position),
            &sum_ker_n_g_c,
        )?;

        // QSumB only supports these two activation packings.
        ensure!(
            mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some()
                || mmm.packings()[packing].1.downcast_ref::<PackedI8K4>().is_some(),
            "Im2Col/QSumB support PackedFormat or PackedI8K4 activation packings"
        );
        // Sums of the im2col'ed input, also used for zero-point compensation.
        let mut sum_x = model.wire_node(
            format!("{name}.sum_x"),
            super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
            &[im2col],
        )?;
        sum_x = model.wire_node(format!("{name}.add_c"), AxisOp::Add(2), &sum_x)?;
        if self.pool_spec.data_format.c_is_last() {
            sum_x =
                model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_x)?;
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        // wire_mm_weights_bias requires the bias in the matmul's internal type.
        let bias_name = &model.node(bias.node).name;
        let bias =
            model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
        let wire = self.wire_mm_weights_bias(
            model,
            name,
            im2col,
            g_o_ihw[0],
            bias,
            mmm,
            packing,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            k,
            c_axis,
            h_axis,
        )?;

        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire[0],
            k.to_dim(),
            k0,
            x0,
            sum_ker[0],
            sum_x[0],
        )?;

        let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
        let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
        let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
        Self::wire_geo_reshape(model, name, &[wire], &output_shape)
    }
293
294 pub fn wire_remove_group<D: DimLike>(
295 &self,
296 model: &mut TypedModel,
297 name: &str,
298 wire: &[OutletId],
299 mmm_output_shape: &[D],
300 c_axis: usize,
301 ) -> TractResult<TVec<OutletId>> {
302 let m = &mmm_output_shape[c_axis];
303 let op = if self.group == 1 {
304 AxisOp::Rm(c_axis - 1)
305 } else {
306 AxisOp::Reshape(
307 c_axis - 1,
308 tvec!(self.group.to_dim(), m.to_dim()),
309 tvec!(m.to_dim() * self.group),
310 )
311 };
312 model.wire_node(format!("{name}.reshape_group"), op, wire)
313 }
314
315 pub unsafe fn wire_as_im2col_pair(
316 &self,
317 model: &mut TypedModel,
318 name: &str,
319 wire: &[OutletId],
320 ) -> TractResult<TVec<OutletId>> {
321 let &[x, w, bias] = wire else { bail!("Wrong number of inputs") };
322 let x_fact = model.outlet_fact(x)?.clone();
323 let w_fact = model.outlet_fact(w)?.clone();
324 let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
325
326 let (_, m, k, n) = self.compute_geo(&x_fact)?;
327 let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
328 let geo_output_shape = self.pool_spec.output_shape(&x_fact.shape)?;
329 let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;
330
331 let padding =
332 model.add_const(format!("{name}.b0"), Tensor::zero_scalar_dt(x_fact.datum_type)?)?;
333
334 let mut wire: TVec<_> = wire.into();
335 wire[0] = model.wire_node(
336 format!("{name}.im2col"),
337 Im2Col::new(
338 self.pool_spec.clone(),
339 self.group,
340 k,
341 &x_fact.shape,
342 mmm.clone(),
343 packing,
344 )?,
345 &[wire[0], padding],
346 )?[0];
347
348 let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, wire[1])?;
349
350 let wire = self
351 .wire_mm_weights_bias(
352 model,
353 name,
354 wire[0],
355 g_o_ihw[0],
356 bias,
357 mmm,
358 packing,
359 c_dt,
360 mmm_output_shape.clone().into(),
361 k.to_usize().unwrap(),
362 c_axis,
363 h_axis,
364 )
365 .context("in wire_opt_matmul")?;
366
367 let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
368 let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
369 Self::wire_geo_reshape(model, name, &wire, &geo_output_shape)
370 }
371
372 fn mmm_output_shape<D: DimLike>(
374 &self,
375 output_shape: &BaseDataShape<D, TVec<D>>,
376 ) -> TractResult<(TVec<D>, usize, usize)> {
377 let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
378 let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.with_n().from_n_c_hw(
379 output_shape.n().cloned().unwrap_or_else(|| 1.into()),
380 output_shape.c().clone(),
381 tvec!(geo_collapsed_out),
382 )?;
383 let mut mmm_output_shape: TVec<D> = shape.shape.clone();
384 let mut c_axis = shape.c_axis();
385 let mut h_axis = shape.h_axis();
386 mmm_output_shape[shape.c_axis()] = mmm_output_shape[c_axis].clone() / self.group;
387 mmm_output_shape.insert(c_axis, self.group.into());
388 if h_axis > c_axis {
389 h_axis += 1;
390 }
391 c_axis += 1;
392 Ok((mmm_output_shape, c_axis, h_axis))
393 }
394
395 fn wire_rm_n_if_needed(
396 &self,
397 model: &mut TypedModel,
398 name: &str,
399 wire: &[OutletId],
400 ) -> TractResult<TVec<OutletId>> {
401 if self.pool_spec.data_format.has_n() {
402 Ok(wire.into())
403 } else {
404 model.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), wire)
405 }
406 }
407
408 fn wire_geo_reshape<D: DimLike>(
409 model: &mut TypedModel,
410 name: &str,
411 wire: &[OutletId],
412 output_shape: &BaseDataShape<D, TVec<D>>,
413 ) -> TractResult<TVec<OutletId>> {
414 let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
415 model
416 .wire_node(
417 name,
418 AxisOp::Reshape(
419 output_shape.h_axis(),
420 tvec!(geo_collapsed_out.to_dim()),
421 output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
422 ),
423 wire,
424 )
425 .context("in wire_geo_reshape")
426 }
427
428 pub unsafe fn wire_as_lazy_im2col(
429 &self,
430 model: &mut TypedModel,
431 name: &str,
432 wire: &[OutletId],
433 ) -> TractResult<TVec<OutletId>> {
434 let &[mut x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
435 let mut x_fact = model.outlet_fact(x)?.clone();
436 let w_fact = model.outlet_fact(kernel)?.clone();
437 let (geo, m, k, n) = self.compute_geo(&x_fact)?;
438 let (mmm, packing) = self.choose_impl(&x_fact, &w_fact, m, k, &n)?;
439 debug!("{name} as lazy_im2col: m={m} k={k} n={n} {mmm:?}");
440 let input_shape = x_fact.shape.as_concrete().unwrap().to_vec();
441 let mut geo = geo.to_concrete(&input_shape)?.into_owned();
442 let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
443 let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
444 if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
445 let mut pads = vec![(0, 0); x_fact.rank()];
446 for (ix, ax) in padding.iter().enumerate() {
447 pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
448 }
449 let op = crate::ops::array::Pad {
450 mode: crate::ops::array::PadMode::Constant(
451 Tensor::zero_scalar_dt(x_fact.datum_type)?.into_arc_tensor(),
452 ),
453 pads,
454 };
455 x = model.wire_node(format!("{name}.pad"), op, &[x])?[0];
456 let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
457 x_fact = model.outlet_fact(x)?.clone();
458 let concrete_shape = x_fact.shape.as_concrete().unwrap();
459 input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
460 geo = valid_pool_spec
461 .compute_geo(&x_fact.shape)?
462 .to_concrete(concrete_shape)?
463 .into_owned();
464 }
465 let c_dt = crate::ops::matmul::output_type(x_fact.datum_type);
466 let c_stride = input_shape.c_stride();
467 let size_of_b = x_fact.datum_type.size_of() as isize;
468 let n_byte_offsets: Vec<isize> =
469 geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
470 let ci_per_group = self.input_channels() / self.group;
473 let k_byte_offsets: Vec<isize> = (0..ci_per_group)
474 .flat_map(|ici| {
475 geo.patch
476 .standard_layout_data_field
477 .iter()
478 .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
479 })
480 .collect();
481 let group_stride_bytes = (ci_per_group * c_stride) as isize * size_of_b;
482 let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
483 let packer = mmm.packings()[packing]
484 .1
485 .downcast_ref::<PackedFormat>()
486 .with_context(|| {
487 format_err!(
488 "Quand Im2Col expects regular packed format, got {:?}",
489 mmm.packings()[packing].1
490 )
491 })?
492 .clone();
493 let params = LazyIm2colParams { packer, n_byte_offsets, k_byte_offsets };
494 let x = model.wire_node(
495 format!("{name}.lazyIm2col"),
496 LazyIm2Col { params: Arc::new(params), group: self.group, group_stride_bytes },
497 &[x],
498 )?[0];
499
500 let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?[0];
501 let wire = self.wire_mm_weights_bias(
502 model,
503 name,
504 x,
505 kernel,
506 bias,
507 mmm,
508 packing,
509 c_dt,
510 mmm_output_shape.clone().into(),
511 k,
512 c_axis,
513 h_axis,
514 )?;
515
516 let wire = self.wire_remove_group(model, name, &wire, &mmm_output_shape, c_axis)?;
517 let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
518 Self::wire_geo_reshape(model, name, &wire, &geo.output_shape)
519 }
520
521 #[allow(clippy::type_complexity)]
522 fn compute_geo(
523 &self,
524 input_fact: &TypedFact,
525 ) -> TractResult<(PoolGeometry, usize, usize, TDim)> {
526 let geo = self.pool_spec.compute_geo(&input_fact.shape)?;
527
528 trace!("output channels: {:?}", self.output_channels());
529 let m = self.output_channels() / self.group;
530 let k = self.input_channels() * self.pool_spec.kernel_shape.iter().product::<usize>()
531 / self.group;
532 let n: TDim =
533 self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();
534 Ok((geo, m, k, n))
535 }
536
537 fn choose_impl(
538 &self,
539 input_fact: &TypedFact,
540 weight_fact: &TypedFact,
541 m: usize,
542 k: usize,
543 n: &TDim,
544 ) -> TractResult<(Box<dyn MatMatMul>, usize)> {
545 let w_dt = weight_fact.datum_type;
546 let x_dt = input_fact.datum_type;
547
548 let acc = if x_dt.is_float() { x_dt } else { i32::datum_type() };
549 if weight_fact.is_exotic() {
550 let bqf = weight_fact
551 .exotic_fact
552 .as_ref()
553 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
554 .unwrap();
555 let weight_type = WeightType::BlockQuant(bqf.format.clone());
556 tract_linalg::ops()
557 .mmm_impls()
558 .iter()
559 .filter(|mmm| mmm.internal_type() == acc)
560 .flat_map(|mmm| {
561 mmm.packings().iter().enumerate().map(move |(ix, p)| (mmm, ix, &p.0, &p.1))
562 })
563 .filter(|(_, _, pa, pb)| {
564 pb.precursor() == x_dt.into() && pa.precursor() == weight_type
565 })
566 .map(|(mmm, p, _, _)| (mmm.clone(), p))
567 .min_by_key(|(mmm, _)| {
568 mmm.quality().cost() as isize * 1000 - (mmm.mr() * mmm.nr()) as isize
569 })
570 .context("Not matmu found")
571 } else {
572 let mmm = tract_linalg::ops()
573 .mmm(acc, Some(m), Some(k), n.to_usize().ok())
574 .context("No matmul found")?;
575 let packing = mmm
576 .packings()
577 .iter()
578 .position(|p| {
579 p.0.precursor() == w_dt.unquantized().into()
580 && p.1.precursor() == x_dt.unquantized().into()
581 })
582 .context("No packing found")?;
583 Ok((mmm, packing))
584 }
585 }
586
587 #[allow(clippy::too_many_arguments)]
588 fn wire_mm_weights_bias(
589 &self,
590 model: &mut TypedModel,
591 name: &str,
592 input: OutletId,
593 g_o_ihw: OutletId,
594 bias: OutletId,
595 mmm: Box<dyn MatMatMul>,
596 packing: usize,
597 c_datum_type: DatumType,
598 mmm_output_shape: ShapeFact,
599 k: usize,
600 c_m_axis: usize,
601 c_n_axis: usize,
602 ) -> TractResult<TVec<OutletId>> {
603 ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
604 let a_pack = &mmm.packings()[packing].0;
605 let packed_ker = self
606 .wire_pack_g_o_ihw(model, name, &**a_pack, g_o_ihw)
607 .context("in kernel_as_packed_as")?;
608 let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());
609
610 c_to_a_axis_mapping.push((c_m_axis - 1, 0)); c_to_b_axis_mapping.push((0, 0)); c_to_b_axis_mapping.push((c_m_axis - 1, 1)); let geo = AddMatMulGeometry {
615 k: k.to_dim(),
616 c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
617 c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
618 };
619 let mut ops: Vec<ProtoFusedSpec> =
620 vec![ProtoFusedSpec::AddMatMul { geo, a: 1, b: 0, packings: vec![(packing, None)] }];
621 let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
622 let bias_fact = model.outlet_fact(bias)?;
623 if bias_fact.konst.is_none() || !bias_fact.konst.as_ref().unwrap().is_all_zero()? {
624 let (fused, bias) = self.wire_bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
625 wires.push(bias);
626 ops.push(fused);
627 }
628 ops.push(ProtoFusedSpec::Store(vec![unsafe {
629 mmm.c_view(Some(c_m_axis), Some(c_n_axis))
630 }]));
631 model.wire_node(
632 format!("{name}.matmatmul"),
633 OptMatMul::new(
634 vec![mmm],
635 ModePicker::Single,
636 c_datum_type.fact(mmm_output_shape),
637 Some(c_m_axis),
638 Some(c_n_axis),
639 ops,
640 packing == 0 && self.group == 1,
641 )?,
642 &wires,
643 )
644 }
645
646 pub fn wire_as_depth_wise(
647 &self,
648 model: &mut TypedModel,
649 name: &str,
650 wire: &[OutletId],
651 ) -> TractResult<OutletId> {
652 let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
653 let x_fact = model.outlet_fact(x)?.clone();
654 let x_shape = x_fact.shape.as_concrete().unwrap();
655 let ConcretePoolGeometry { input_shape, patch, output_shape } =
656 self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
657 let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
658 let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
659 bias = wire_reshape_bias_for_bin(
660 model,
661 name,
662 bias,
663 x_fact.rank(),
664 c_axis,
665 self.output_channels(),
666 )?[0];
667 let op = DepthWise::new(patch, input_shape, output_shape);
668 Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
669 }
670
671 fn try_blocked_conv(&self, input_fact: &TypedFact) -> Option<super::BlockedConv> {
677 let enabled = if cfg!(target_family = "wasm") {
682 std::env::var("TRACT_DISABLE_BLOCKED_CONV").is_err()
683 } else {
684 std::env::var("TRACT_ENABLE_BLOCKED_CONV").is_ok()
685 };
686 if !enabled {
687 return None;
688 }
689 if self.q_params.is_some() {
690 return None;
691 }
692 if input_fact.datum_type != f32::datum_type() {
693 return None;
694 }
695 if self.pool_spec.data_format != crate::ops::nn::DataFormat::NCHW {
696 return None;
697 }
698 if self.pool_spec.rank() != 2 || self.pool_spec.kernel_shape[1] != 1 {
699 return None;
700 }
701 if self.pool_spec.stride(1) != 1 || self.pool_spec.dilation(1) != 1 {
702 return None;
703 }
704 let group = self.group;
705 let oc = self.output_channels();
706 let c_in = self.input_channels();
707 if group == 0 || !oc.is_multiple_of(group) || !c_in.is_multiple_of(group) {
708 return None;
709 }
710 let ocg = oc / group;
711 if ocg == 0 || ocg > 8 {
714 return None;
715 }
716 let concrete = input_fact.shape.as_concrete()?;
717 let shape = self.pool_spec.data_format.shape(concrete).ok()?;
718 let h_axis = shape.h_axis();
719 let h_in = concrete[h_axis];
720 let w = concrete[h_axis + 1];
721 let pads = self.pool_spec.computed_padding(shape.hw_dims());
722 Some(super::BlockedConv {
723 n: *shape.n().unwrap_or(&1),
724 c_in,
725 h_in,
726 w,
727 oc,
728 group,
729 kh: self.pool_spec.kernel_shape[0],
730 stride_h: self.pool_spec.stride(0),
731 dil_h: self.pool_spec.dilation(0),
732 pad_before_h: pads[0].pad_before,
733 h_out: pads[0].convoluted,
734 })
735 }
736
737 fn wire_as_blocked_conv(
738 &self,
739 model: &mut TypedModel,
740 name: &str,
741 wire: &[OutletId],
742 op: super::BlockedConv,
743 ) -> TractResult<OutletId> {
744 let &[x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
745 let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
748 Ok(model.wire_node(name, op, &[x, g_o_ihw[0], bias])?[0])
749 }
750
751 fn declutter_stride_slice_to_downsample(
752 &self,
753 model: &TypedModel,
754 node: &TypedNode,
755 ) -> TractResult<Option<TypedModelPatch>> {
756 let spatial_rank = self.pool_spec.rank();
757 if let Some(axis) = (0..spatial_rank).find(|&ax| {
758 self.pool_spec.stride(ax) > 1
759 && self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
760 && (self.pool_spec.kernel_shape[ax] == 1
761 || self.pool_spec.dilation(ax).is_multiple_of(self.pool_spec.stride(ax)))
762 }) {
763 let input_fact = model.outlet_fact(node.inputs[0])?;
764 let downsample_factor = self.pool_spec.stride(axis);
765 let mut new_op = self.clone();
766 if new_op.pool_spec.dilation(axis) > 1 {
767 new_op.pool_spec.dilations.as_mut().unwrap()[axis] =
768 new_op.pool_spec.dilations.as_mut().unwrap()[axis].divceil(downsample_factor);
769 }
770 new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
771 let mut patch = TypedModelPatch::default();
772 let mut taps = patch.taps(model, &node.inputs)?;
773 let shape = self.pool_spec.data_format.shape(&input_fact.shape)?;
774 taps[0] = patch.wire_node(
775 format!("{}.downsample.{}", node.name, axis),
776 crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
777 &[taps[0]],
778 )?[0];
779 let id = patch.wire_node(&*node.name, new_op, &taps)?[0];
780 patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
781 return Ok(Some(patch));
782 }
783 Ok(None)
784 }
785
    /// Rewrite a pointwise convolution as an `EinSum`.
    ///
    /// Applies only when all of these hold:
    /// - `group == 1`;
    /// - all strides and dilations are 1;
    /// - the kernel has a single spatial element;
    /// - the computed padding is zero.
    ///
    /// Otherwise returns `Ok(None)`. The quantized variant feeds every input (bias
    /// included) to the einsum. The float variant adds the bias afterwards, unless it is a
    /// known constant zero.
    fn declutter_as_einsum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_facts, output_facts) = model.node_facts(node.id)?;
        let full_input_shape = input_facts[0].shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if self.group == 1
            && self.pool_spec.strides().iter().all(|s| *s == 1)
            && self.pool_spec.dilations().iter().all(|d| *d == 1)
            && self.pool_spec.kernel_shape.iter().product::<usize>() == 1
            && self
                .pool_spec
                .computed_padding(input_shape.hw_dims())
                .iter()
                .all(|pad| pad.pad_after.is_zero() && pad.pad_before.is_zero())
        {
            let mut axes = self.axes_mapping(&input_facts, &output_facts)?;
            let mut patch = TypedModelPatch::new("declutter_as_einsum");
            let mut taps = patch.taps(model, &node.inputs)?;
            let name = &node.name;
            let co = self.output_channels();
            // Bring the kernel to g_o_ihw, then drop axis 0 (the single group).
            // Presumably this leaves [o, i], since the kernel volume is 1; confirm.
            taps[1] =
                self.wire_kernel_as_g_o_ihw(&mut patch, &format!("{name}.filters"), taps[1])?[0];
            taps[1] =
                patch.wire_node(format!("{name}.filters_as_co_ci"), AxisOp::Rm(0), &[taps[1]])?[0];

            // Replace every kernel axis in the mapping with the two axes 'O' and 'I'.
            while axes.rank(InOut::In(1)) > 0 {
                axes = axes.remove_axis_occurency(InOut::In(1), 0)?;
            }
            axes = axes
                .with_extra_axis_occurency('O', InOut::In(1), 0)?
                .with_extra_axis_occurency('I', InOut::In(1), 1)?;

            let bias_fact = input_facts[2];
            let wire = if self.q_params.is_some() {
                // Quantized: the einsum consumes all inputs, the bias linked to 'O' if rank 1.
                if bias_fact.rank() == 1 {
                    axes = axes.linking('O', (InOut::In(2), 0))?;
                }
                let op = EinSum { axes, operating_dt: i32::datum_type(), q_params: self.q_params };
                patch.wire_node(format!("{name}.einsum"), op, &taps)?[0]
            } else {
                // Float: einsum on input and kernel only.
                axes = axes.remove_slot(InOut::In(2))?;
                let op = EinSum { axes, operating_dt: input_facts[0].datum_type, q_params: None };
                let mut wire = patch.wire_node(format!("{name}.einsum"), op, &taps[0..2])?[0];

                // Skip the add only when the bias is a known constant equal to zero.
                if !bias_fact.konst.as_ref().map(|f| f.is_zero()).transpose()?.unwrap_or(false) {
                    // Reshape the bias so it broadcasts along the channel axis.
                    let bias_current_shape =
                        if bias_fact.rank() == 0 { tvec!() } else { tvec!(co.to_dim()) };
                    let mut bias_shape = tvec!(1.to_dim(); input_shape.rank());
                    if bias_fact.rank() > 0 {
                        bias_shape[input_shape.c_axis()] = co.to_dim();
                    }
                    let b = patch.wire_node(
                        format!("{name}.bias.reshape"),
                        AxisOp::Reshape(0, bias_current_shape, bias_shape),
                        &[taps[2]],
                    )?[0];
                    wire = patch.wire_node(
                        format!("{name}.bias"),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            // The final node takes over the original convolution's name.
            patch.node_mut(wire.node).name = node.name.to_string();
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }
859
860 fn declutter_precursor_padding(
861 &self,
862 model: &TypedModel,
863 node: &TypedNode,
864 ) -> TractResult<Option<TypedModelPatch>> {
865 rule_if!(!matches!(
866 self.pool_spec.padding,
867 ExplicitOnnxPool(_, _, _) | SameLower | SameUpper
868 ));
869 let prec = model.node(node.inputs[0].node);
870 rule_if_some!(pad = prec.op_as::<Pad>());
871 rule_if_let!(PadMode::Constant(value) = &pad.mode);
872 let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
873 rule_if!(value.is_zero()?);
874 rule_if!(pad.pads[shape.c_axis()] == (0, 0));
875 if self.pool_spec.data_format.has_n() {
876 rule_if!(pad.pads[0] == (0, 0));
877 }
878 let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
879 let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
880 if let Explicit(bef, aft) = &self.pool_spec.padding {
881 izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
882 izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
883 }
884 let padding = Explicit(before, after);
885 let mut new = self.clone();
886 new.pool_spec.padding = padding;
887 let mut patch = TypedModelPatch::default();
888 let mut wire = patch.taps(model, &node.inputs)?;
889 wire[0] = patch.tap_model(model, prec.inputs[0])?;
890 let wire = patch.wire_node(&node.name, new, &wire)?;
891 patch.shunt_outside(model, node.id.into(), wire[0])?;
892 Ok(Some(patch))
893 }
894
895 fn declutter_channel_arithmetic_succ(
896 &self,
897 model: &TypedModel,
898 node: &TypedNode,
899 ) -> TractResult<Option<TypedModelPatch>> {
900 rule_if!(self.q_params.is_none());
901 rule_if!(self.group == 1);
902 rule_if_let!(&[succ_outlet] = &*node.outputs[0].successors);
903 let succ = model.node(succ_outlet.node);
904 rule_if_some!(bin = succ.op_as::<TypedBinOp>());
905 let other_input = succ.inputs[1 - succ_outlet.slot];
906 let axes_mapping = model.node_axes_mapping(succ.id)?;
907 let input_shape =
908 self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
909 let conv_c_axis = input_shape.c_axis();
910 rule_if!(
911 axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
912 [1 - succ_outlet.slot]
913 .len()
914 == 1
915 );
916 let mut other_expected_shape = tvec!(1.to_dim(); input_shape.rank());
917 other_expected_shape[conv_c_axis] = self.output_channels().to_dim();
918 rule_if!(*other_expected_shape == *model.outlet_fact(other_input)?.shape);
919
920 let mut patch = TypedModelPatch::default();
921 let [input, mut kernel, mut bias] = *patch.taps(model, &node.inputs)? else {
922 panic!("Expect three inputs");
923 };
924 let name = &node.name;
925 let succ_name = &succ.name;
926
927 let operand = patch.tap_model(model, other_input)?;
928
929 let renamed_bias = format!("{name}.{succ_name}.bias");
930 let renamed_kernel = format!("{name}.{succ_name}.kernel");
931 bias = wire_reshape_bias_for_bin(
932 &mut patch,
933 format!("{renamed_bias}.reshape"),
934 bias,
935 1,
936 0,
937 self.output_channels(),
938 )?[0];
939
940 let operand = wire_reshape_bias_for_bin(
941 &mut patch,
942 format!("{renamed_bias}.reshape_operand"),
943 operand,
944 1,
945 0,
946 self.output_channels(),
947 )?[0];
948
949 let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
950 let kernel_fact = patch.outlet_fact(kernel)?;
951 let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
952 operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
953 self.output_channels().to_dim();
954 let operand_for_kernel = patch.wire_node(
955 format!("{renamed_kernel}.reshape_operand"),
956 AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
957 &[operand],
958 )?[0];
959
960 if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
961 bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
962 } else if bin.0.is::<Sub>() {
963 bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
964 } else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
965 bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
966 kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
967 } else if bin.0.is::<Div>() {
968 bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
969 kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
970 } else if bin.0.is::<Add>() {
971 bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
972 } else if bin.0.is::<Mul>() {
973 bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
974 kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
975 } else {
976 return Ok(None);
977 };
978 let wire = patch.wire_node(&node.name, self.clone(), &[input, kernel, bias])?[0];
979 patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
980 Ok(Some(patch))
981 }
982}
983
984impl Op for Conv {
985 fn name(&self) -> StaticName {
986 "Conv".into()
987 }
988
989 fn info(&self) -> TractResult<Vec<String>> {
990 let mut info = self.pool_spec.info();
991 info.push(format!("Kernel {:?} (groups:{})", self.kernel_fmt, self.group));
992 Ok(info)
993 }
994
995 fn validation(&self) -> Validation {
996 Validation::Rounding
997 }
998
999 op_as_typed_op!();
1000}
1001
1002impl EvalOp for Conv {
1003 fn is_stateless(&self) -> bool {
1004 true
1005 }
1006
1007 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
1008 let mut model = TypedModel::default();
1009 let wire: TVec<OutletId> = inputs
1010 .iter()
1011 .enumerate()
1012 .map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
1013 .collect::<TractResult<_>>()?;
1014 let wire = unsafe {
1015 if self.q_params.is_some() {
1016 self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
1017 } else {
1018 self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
1019 }
1020 };
1021 model.select_output_outlets(&wire)?;
1022 model.into_runnable()?.run(inputs)
1023 }
1024}
1025
1026impl TypedOp for Conv {
1027 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1028 ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
1029 let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
1030 ensure!(inputs[1].datum_type.is_number() || self.kernel_fmt == KernelFormat::OIHW);
1031 if inputs.len() != 3 + q_inputs {
1032 bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
1033 }
1034 if self.q_params.is_some() {
1035 ensure!(inputs[2].datum_type == i32::datum_type());
1036 ensure!(inputs[3].datum_type == i32::datum_type());
1037 ensure!(inputs[4].datum_type.is_float());
1038 ensure!(inputs[5].datum_type == i32::datum_type());
1039 ensure!(inputs[6].datum_type.is_float());
1040 ensure!(inputs[7].datum_type == i32::datum_type());
1041 ensure!(inputs[8].datum_type.is_float());
1042 }
1043 ensure!(self.pool_spec.rank() + 2 == inputs[1].shape.len());
1044 if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
1045 != &self.input_channels().to_dim()
1046 {
1047 bail!(
1048 "Inconsistent convolution: input is {:?}, but kernel expects {} input channels.\n{:?}",
1049 inputs[0],
1050 self.input_channels(),
1051 self
1052 );
1053 }
1054 if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
1055 anyhow::ensure!(bef.len() == self.pool_spec.rank());
1056 anyhow::ensure!(after.len() == self.pool_spec.rank());
1057 }
1058 ensure!(
1059 inputs[2].rank() == 0
1060 || (inputs[2].rank() == 1
1061 && inputs[2].shape.volume() == self.output_channels().to_dim()),
1062 "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
1063 self.output_channels(),
1064 inputs[2]
1065 );
1066 let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
1067 if let Some(dt) = self.q_params {
1068 fact.datum_type = dt;
1069 } else {
1070 ensure!(
1071 inputs[1].is_exotic() || inputs[0].datum_type == inputs[1].datum_type,
1072 "Convolution input, weights and bias must have the same type, got {inputs:?}",
1073 )
1074 }
1075 Ok(tvec!(fact))
1076 }
1077
1078 fn axes_mapping(
1079 &self,
1080 inputs: &[&TypedFact],
1081 outputs: &[&TypedFact],
1082 ) -> TractResult<AxesMapping> {
1083 let fact = &inputs[0];
1084 let shape = self.pool_spec.data_format.shape(&fact.shape)?;
1085 let mut axes = AxesMapping::disconnected(inputs, outputs)?
1086 .renaming((InOut::In(0), shape.c_axis()), 'I')?
1087 .renaming((InOut::Out(0), shape.c_axis()), 'O')?;
1088 if let Some(n_axis) = shape.n_axis() {
1089 axes = axes
1090 .renaming((InOut::In(0), n_axis), 'N')?
1091 .linking('N', (InOut::Out(0), n_axis))?;
1092 }
1093 let h_axis = shape.h_axis();
1094 let geo = "HWXYZ".chars().chain('a'..);
1095 let kernel_spatial_shape = &self.pool_spec.kernel_shape;
1096 let padding = self.pool_spec.computed_padding(shape.hw_dims());
1097 for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) {
1098 if dim == 1
1099 && self.pool_spec.dilation(ix) == 1
1100 && self.pool_spec.stride(ix) == 1
1101 && padding[ix].pad_before.is_zero()
1102 && padding[ix].pad_after.is_zero()
1103 {
1104 axes = axes
1105 .renaming((InOut::In(0), ix + h_axis), repr)?
1106 .linking(repr, (InOut::Out(0), ix + h_axis))?;
1107 }
1108 }
1109 if self.q_params.is_some() {
1110 for (qp_ix, qp) in inputs.iter().enumerate().skip(3) {
1111 if qp.rank() == 1 {
1112 axes = match qp_ix {
1113 3 | 4 => axes.linking('I', (InOut::In(qp_ix), 0))?,
1114 5 | 6 => axes.linking('O', (InOut::In(qp_ix), 0))?,
1115 7 | 8 => axes.linking('O', (InOut::In(qp_ix), 0))?,
1116 _ => unreachable!(),
1117 };
1118 }
1119 }
1120 }
1121 Ok(axes)
1122 }
1123
1124 fn declutter(
1125 &self,
1126 model: &TypedModel,
1127 node: &TypedNode,
1128 ) -> TractResult<Option<TypedModelPatch>> {
1129 macro_rules! pass {
1130 ($func:ident) => {
1131 if let Some(mut r) = self.$func(model, node).context(stringify!($func))? {
1132 trace!(stringify!($func));
1133 r.push_context(stringify!($func));
1134 return Ok(Some(r));
1135 }
1136 };
1137 }
1138 pass!(declutter_stride_slice_to_downsample);
1139 pass!(declutter_as_einsum);
1140 pass!(declutter_channel_arithmetic_succ);
1141 pass!(declutter_precursor_padding);
1142 Ok(None)
1143 }
1144
1145 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
1146 let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
1147 let kernel_spatial_shape = &self.pool_spec.kernel_shape;
1148 let output_dims = self.pool_spec.padding.compute(
1149 shape.hw_dims(),
1150 kernel_spatial_shape,
1151 &self
1152 .pool_spec
1153 .dilations
1154 .clone()
1155 .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
1156 &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
1157 );
1158 let n_output_points: TDim =
1159 output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
1160 let n_output_channels = self.output_channels().to_dim();
1161 let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
1162 let one = 1.to_dim();
1163 Ok(tvec!((
1164 Cost::FMA(inputs[0].datum_type),
1165 shape.n().cloned().unwrap_or(one)
1166 * shape.c()
1167 * n_output_channels
1168 * n_output_points
1169 * kernel_surface
1170 / self.group
1171 )))
1172 }
1173
1174 fn change_axes(
1175 &self,
1176 model: &TypedModel,
1177 node: &TypedNode,
1178 io: InOut,
1179 change: &AxisOp,
1180 ) -> TractResult<Option<AxisChangeConsequence>> {
1181 rule_if!(io != InOut::In(1));
1182 if io == InOut::In(2)
1183 && let &AxisOp::Rm(_) = change
1184 {
1185 return Ok(Some(AxisChangeConsequence {
1186 substitute_op: Some(Box::new(self.clone())),
1187 wire_changes: tvec!(),
1188 }));
1189 }
1190 let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
1191 let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
1192 if let Some(n) = shape.n_axis() {
1194 assert_eq!(n, 0);
1195 if change == &AxisOp::Rm(n) {
1196 let op = Conv { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
1197 return Ok(Some(AxisChangeConsequence {
1198 substitute_op: Some(Box::new(op)),
1199 wire_changes: tvec!(
1200 (InOut::In(0), change.clone()),
1201 (InOut::Out(0), change.clone())
1202 ),
1203 }));
1204 }
1205 rule_if!(change.transform_axis(n).map(|axis| axis == 0).unwrap_or(false));
1206 }
1207 let (new_format, axis_move) = match self.pool_spec.data_format {
1209 DataFormat::NCHW => {
1210 (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
1211 }
1212 DataFormat::CHW => {
1213 (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
1214 }
1215 DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
1216 DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
1217 };
1218 if *change == axis_move {
1219 let mut new_op = self.clone();
1220 new_op.pool_spec.data_format = new_format;
1221 return Ok(Some(AxisChangeConsequence {
1222 substitute_op: Some(Box::new(new_op)),
1223 wire_changes: tvec!(
1224 (InOut::In(0), change.clone()),
1225 (InOut::Out(0), change.clone())
1226 ),
1227 }));
1228 }
1229 rule_if!(!model.node_input_facts(node.id)?[1].is_exotic());
1231 use AxisOp::*;
1232 let h_axis = shape.h_axis();
1233 let hw_axes = shape.hw_axes();
1234 let kh_axis = self.kernel_fmt.h_axis();
1235 let (geo_adjusted, kernel_adjusted) = match change {
1236 Rm(a)
1237 if hw_axes.contains(a)
1238 && hw_axes.len() > 1
1239 && self.pool_spec.dilation(a - h_axis) == 1
1240 && self.pool_spec.stride(a - h_axis) == 1
1241 && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
1242 {
1243 let geo_axis = a - h_axis;
1244 (Rm(geo_axis), Rm(kh_axis + geo_axis))
1245 }
1246 Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
1247 Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
1248 (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
1249 }
1250 _ => return Ok(None),
1251 };
1252 let pool_spec = self.pool_spec.change_geo_axes(&geo_adjusted)?;
1253 let new_op = Conv { pool_spec, ..self.clone() };
1254 Ok(Some(AxisChangeConsequence {
1255 substitute_op: Some(Box::new(new_op)),
1256 wire_changes: tvec!(
1257 (InOut::In(0), change.clone()),
1258 (InOut::In(1), kernel_adjusted),
1259 (InOut::Out(0), change.clone())
1260 ),
1261 }))
1262 }
1263
1264 fn codegen(
1265 &self,
1266 model: &TypedModel,
1267 node: &TypedNode,
1268 ) -> TractResult<Option<TypedModelPatch>> {
1269 let input_fact = model.outlet_fact(node.inputs[0])?;
1270 unsafe {
1271 if self.q_params.is_some() {
1272 let mut patch = TypedModelPatch::new("quantized-codegen");
1273 let inputs = patch.taps(model, &node.inputs)?;
1274 let wire = self
1275 .wire_as_quant_im2col(&mut patch, &node.name, &inputs)
1276 .context("in wire_as_quant_im2col")?;
1277 patch.shunt_outside(model, node.id.into(), wire[0])?;
1278 patch.obliterate(node.id)?;
1279 Ok(Some(patch))
1280 } else if let Some(op) = self.try_blocked_conv(input_fact) {
1281 let mut patch = TypedModelPatch::new("blocked-conv");
1284 let inputs = patch.taps(model, &node.inputs)?;
1285 let wire = self
1286 .wire_as_blocked_conv(&mut patch, &node.name, &inputs, op)
1287 .context("wire_as_blocked_conv")?;
1288 patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
1289 patch.obliterate(node.id)?;
1290 Ok(Some(patch))
1291 } else if input_fact
1292 .shape
1293 .as_concrete()
1294 .map(|s| should_use_lazy(&self.pool_spec, self.group, s, input_fact.datum_type))
1295 .unwrap_or(false)
1296 {
1297 let mut patch = TypedModelPatch::new("lazy-im2col");
1298 let inputs = patch.taps(model, &node.inputs)?;
1299 let wire = self
1300 .wire_as_lazy_im2col(&mut patch, &node.name, &inputs)
1301 .context("wire_as_lazy_im2col")?[0];
1302 patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
1303 patch.obliterate(node.id)?;
1304 Ok(Some(patch))
1305 } else if self.group != 1
1306 && self.group == self.output_channels()
1307 && self.group == self.input_channels()
1308 && input_fact.shape.as_concrete().is_some()
1309 {
1310 let mut patch = TypedModelPatch::new("depth_wise");
1311 let inputs = patch.taps(model, &node.inputs)?;
1312 let wire = self
1313 .wire_as_depth_wise(&mut patch, &node.name, &inputs)
1314 .context("wire_as_depth_wise")?;
1315 patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
1316 patch.obliterate(node.id)?;
1317 Ok(Some(patch))
1318 } else {
1319 let mut patch = TypedModelPatch::new("im2col");
1320 let inputs = patch.taps(model, &node.inputs)?;
1321 let wire = self
1322 .wire_as_im2col_pair(&mut patch, &node.name, &inputs)
1323 .context("in wire_as_im2col_pair")?[0];
1324 patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
1325 patch.obliterate(node.id)?;
1326 Ok(Some(patch))
1327 }
1328 }
1329 }
1330
1331 as_op!();
1332}
1333
/// Default kernel volume (product of the spatial kernel dimensions) at or above
/// which `should_use_lazy` picks the lazy im2col path.
/// Overridable through `TRACT_LAZY_IM2COL_MIN_KERNEL`.
const DEFAULT_LAZY_IM2COL_MIN_KERNEL: usize = 6;

/// Reads `var` as a `usize` on first use and caches it in `cell`.
///
/// `default` is used when the variable is unset or does not parse as a `usize`.
/// Later calls return the cached value and ignore further environment changes.
fn cached_usize_from_env(
    cell: &std::sync::OnceLock<usize>,
    var: &str,
    default: usize,
) -> usize {
    *cell.get_or_init(|| match std::env::var(var) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    })
}

/// Effective minimum kernel volume for the lazy im2col path.
fn lazy_im2col_min_kernel() -> usize {
    static CACHE: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    cached_usize_from_env(&CACHE, "TRACT_LAZY_IM2COL_MIN_KERNEL", DEFAULT_LAZY_IM2COL_MIN_KERNEL)
}

/// Default threshold, in bytes, on the estimated eager im2col scratch size;
/// at or above it `should_use_lazy` picks the lazy path. The threshold is
/// smaller on wasm targets. Overridable through `TRACT_LAZY_IM2COL_MAX_EAGER_BYTES`.
#[cfg(target_family = "wasm")]
const DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES: usize = 1024 * 1024;
#[cfg(not(target_family = "wasm"))]
const DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES: usize = 4 * 1024 * 1024;

/// Effective eager-scratch byte threshold for the lazy im2col path.
fn lazy_im2col_max_eager_bytes() -> usize {
    static CACHE: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    cached_usize_from_env(
        &CACHE,
        "TRACT_LAZY_IM2COL_MAX_EAGER_BYTES",
        DEFAULT_LAZY_IM2COL_MAX_EAGER_BYTES,
    )
}
1387
1388fn should_use_lazy(
1389 pool_spec: &PoolSpec,
1390 group: usize,
1391 input_shape: &[usize],
1392 dt: DatumType,
1393) -> bool {
1394 let is_depthwise =
1399 group > 1 && group == pool_spec.input_channels && group == pool_spec.output_channels;
1400 if is_depthwise {
1401 return false;
1402 }
1403 let Ok(output_shape) = pool_spec.output_shape(input_shape) else { return false };
1404 if output_shape.n().unwrap_or(&1) != &1 {
1406 return false;
1407 }
1408 let kernel_volume = pool_spec.kernel_shape.iter().product::<usize>();
1409 if kernel_volume >= lazy_im2col_min_kernel() {
1412 return true;
1413 }
1414 let n: usize = output_shape.hw_dims().iter().product();
1418 let k = pool_spec.input_channels * kernel_volume / group;
1419 let eager_scratch_bytes = k.saturating_mul(n).saturating_mul(dt.size_of());
1420 eager_scratch_bytes >= lazy_im2col_max_eager_bytes()
1421}
1422
#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use DataFormat::*;

    /// Evaluates a quantized 2x2 convolution over a 3x3 single-channel input
    /// through `Conv::eval`.
    ///
    /// Each expected value equals the corresponding 2x2 window sum of the input
    /// minus 4; for example 1 + 2 + 4 + 5 - 4 = 8.
    #[test]
    fn onnx_basic_convinteger() {
        let conv = Conv {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: Valid,
                dilations: None,
                strides: None,
                input_channels: 1,
                output_channels: 1,
            },
            kernel_fmt: KernelFormat::OIHW,
            group: 1,
            q_params: Some(i32::datum_type()),
        };
        let operands: TVec<TValue> = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]).into_tvalue(),
            rctensor4(&[[[[1u8, 1], [1, 1]]]]).into_tvalue(),
            rctensor0(0u32).into_tvalue(),
            rctensor0(1u8).into_tvalue(),
            rctensor0(1.0f32).into_tvalue(),
            rctensor0(0u8).into_tvalue(),
            rctensor0(1.0f32).into_tvalue(),
            rctensor0(0i32).into_tvalue(),
            rctensor0(1.0f32).into_tvalue()
        );
        let results = conv.eval(operands).unwrap();
        assert_eq!(*results[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    /// Checks that decluttering folds a zero-valued `Pad` in front of a conv
    /// with zero explicit padding into the conv's padding.
    ///
    /// Afterwards the model has four nodes, the last being the `Conv`, whose
    /// padding is now `Explicit([1], [0])`.
    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let source = model.add_source("source", f32::fact(dims!(1, 10)))?;
        let padded = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &[source],
        )?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;
        let conv_spec = PoolSpec {
            data_format: crate::ops::nn::DataFormat::CHW,
            dilations: None,
            strides: None,
            kernel_shape: tvec![2],
            padding: Explicit(tvec![0], tvec![0]),
            input_channels: 1,
            output_channels: 1,
        };
        let conv = Conv {
            pool_spec: conv_spec,
            kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
            group: 1,
            q_params: None,
        };
        let output = model.wire_node("conv", conv, &[padded[0], kernel, bias])?;
        model.select_output_outlets(&output)?;
        model.declutter()?;

        assert_eq!(model.nodes().len(), 4);
        let folded = model.nodes()[3].op_as::<Conv>().unwrap();
        assert_eq!(folded.pool_spec.padding, Explicit(tvec![1], tvec![0]));
        Ok(())
    }
}