1use tract_linalg::mmm::{
2 EagerPackedInput, MMMInputFormat, MMMInputValue, MatMatMul, PackedExoticFact,
3 PackedMatrixStorage,
4};
5use tract_linalg::pack::{PackedFormat, PackedI8K4, PackingWriter};
6
7use crate::internal::*;
8use ndarray::prelude::*;
9use num_integer::Integer;
10
11use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
12use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
13use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
14
/// Im2col operator: extracts the convolution patches of its input and writes them, already
/// packed, as the matrix operand of a matrix multiplication kernel (`MatMatMul`).
///
/// Takes the data tensor and, optionally, a second input holding the padding value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2Col {
    /// Window specification used to compute the patch geometry and the output shape.
    pub pool_spec: PoolSpec,
    /// Number of convolution groups; each group handles `c / group` input channels.
    pub group: usize,
    /// Patch geometry: resolved at construction time when the input shape is fully concrete,
    /// otherwise resolved from the actual input shape at eval time.
    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
}
21
/// Geometry of the op while the input shape may still be symbolic.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SymbolicGeometry {
    /// Number of convolution groups (copied from the op).
    group: usize,
    /// Window specification (copied from the op); its data format drives batch-axis handling.
    pool_spec: PoolSpec,
    /// Pool geometry computed from `pool_spec` and the (possibly symbolic) input shape.
    pool_geometry: PoolGeometry,
    /// Packing format of the produced matrix, taken from the matmul kernel's packings.
    out_format: Box<dyn MMMInputFormat>,
    /// `k` dimension of the packed matrix, as passed to `Im2Col::new`.
    k: usize,
}
31
/// Geometry of the op specialized for one fully known input shape.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConcreteGeometry {
    /// Concrete pool geometry: input/output data shapes and the patch description.
    pool: ConcretePoolGeometry,
    /// Number of output spatial positions (product of the output spatial dims).
    pub n: usize,
    /// `k` dimension of the packed matrix.
    k: usize,
    /// Packing format of the produced matrix.
    pub out_format: Box<dyn MMMInputFormat>,
    /// Number of input channels handled by each group.
    pub ci_per_group: usize,
    /// Specialized patching routine chosen from the patch rank and padding.
    patcher: Patcher,
    /// Input data shape with an explicit batch axis (N = 1 added for HWC/CHW formats).
    input_shape_with_n: DataShape,
    /// Shape of the packed output storage: `[batch, group]`.
    packed_shape: TVec<usize>,
}
43
44impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
45 pub fn out_format(&self) -> &dyn MMMInputFormat {
46 match self {
47 GeometryBound::Symbolic(s) => &*s.out_format,
48 GeometryBound::Concrete(s) => &*s.out_format,
49 }
50 }
51 pub fn k(&self) -> usize {
52 match self {
53 GeometryBound::Symbolic(s) => s.k,
54 GeometryBound::Concrete(s) => s.k,
55 }
56 }
57}
58
59impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
60 type Param = [usize];
61 fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
62 let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
63 let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
64 Patcher::Valid2d
65 } else if pool.patch.rank() == 2 {
66 Patcher::Padded2d
67 } else if !pool.patch.padded && pool.patch.rank() == 1 {
68 Patcher::Valid1d
69 } else {
70 Patcher::Generic
71 };
72 let ci_per_group = pool.input_shape.c_dim() / self.group;
73 let n = pool.output_shape.hw_dims().iter().product();
74 let input_shape_with_n = match self.pool_spec.data_format {
75 DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
76 1,
77 *pool.input_shape.c(),
78 pool.input_shape.hw_dims(),
79 )?,
80 DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
81 1,
82 *pool.input_shape.c(),
83 pool.input_shape.hw_dims(),
84 )?,
85 _ => pool.input_shape.clone(),
86 };
87 let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
88 Ok(ConcreteGeometry {
89 pool,
90 n,
91 k: self.k,
92 ci_per_group,
93 out_format: self.out_format.clone(),
94 patcher,
95 input_shape_with_n,
96 packed_shape,
97 })
98 }
99}
100
101impl Im2Col {
102 pub fn new(
103 pool_spec: PoolSpec,
104 group: usize,
105 k: usize,
106 input_full_shape: &ShapeFact,
107 mmm: Box<dyn MatMatMul>,
108 packing: usize,
109 ) -> TractResult<Im2Col> {
110 let out_format = dyn_clone::clone_box(&*mmm.packings()[packing].1);
111 let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
112 let geometry: GeometryBound<_, _> =
113 SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, out_format, k }
114 .into();
115 let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
116 Ok(Im2Col { pool_spec, group, geometry })
117 }
118
119 fn packed_shape<D: DimLike>(
121 input_shape: &BaseDataShape<D, TVec<D>>,
122 group: usize,
123 ) -> TractResult<TVec<D>> {
124 let mut output_shape: TVec<D> = tvec!();
125 output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
126 output_shape.push(group.into());
127 Ok(output_shape)
128 }
129}
130
131impl Op for Im2Col {
132 fn name(&self) -> StaticName {
133 "Im2col".into()
134 }
135
136 fn info(&self) -> TractResult<Vec<String>> {
137 Ok(vec![format!("groups:{}", self.group)])
138 }
139
140 op_as_typed_op!();
141}
142
143impl EvalOp for Im2Col {
144 fn is_stateless(&self) -> bool {
145 true
146 }
147
148 fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
149 let geometry = self.geometry.to_concrete(inputs[0].shape())?;
150 unsafe {
151 let mut input = inputs.remove(0).into_tensor();
152 let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
153 if !self.pool_spec.data_format.has_n() {
154 input.insert_axis(0)?;
155 }
156 let dt = input.datum_type();
157 let r = geometry.out_format.r();
158 let (single_panel_len, buf_align, zero_init) =
163 if let Some(pf) = geometry.out_format.downcast_ref::<PackedFormat>() {
164 (pf.single_panel_len(geometry.k), pf.alignment(), false)
165 } else if let Some(p4) = geometry.out_format.downcast_ref::<PackedI8K4>() {
166 (p4.single_panel_len(geometry.k), p4.alignment(), true)
167 } else {
168 bail!("Im2Col: unsupported packing format {:?}", geometry.out_format)
169 };
170 let panel_bytes = single_panel_len * dt.size_of();
171
172 let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1);
173 let n_groups = self.group;
174 let mut values: Vec<Box<dyn MMMInputValue>> = Vec::with_capacity(n_batches * n_groups);
175
176 for i in 0..n_batches {
177 let input = input.view_at_prefix(&[i])?;
178 for g in 0..n_groups {
179 let n =
180 if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n };
181 let mut data = Tensor::uninitialized_aligned_dt(
182 dt,
183 &[n.divceil(r) * single_panel_len],
184 buf_align,
185 )?;
186 if zero_init {
187 data.as_bytes_mut().fill(0);
188 }
189 if n > 0 {
190 dispatch_copy_by_size!(Patcher::patch(dt)(
191 &geometry.patcher,
192 &geometry,
193 &input,
194 &mut data.view_mut(),
195 g,
196 pad_value
197 ))?;
198 }
199 values.push(Box::new(EagerPackedInput {
200 fact: PackedExoticFact {
201 format: geometry.out_format.clone(),
202 k: geometry.k,
203 mn: n.to_dim(),
204 },
205 packed: data.into_blob()?.into(),
206 panel_bytes: if n > 0 { panel_bytes } else { 0 },
207 mn: n,
208 }));
209 }
210 }
211
212 let output = PackedMatrixStorage::new_batched(&geometry.packed_shape, values)
213 .into_tensor(input.datum_type());
214 Ok(tvec!(output.into_tvalue()))
215 }
216 }
217}
218
219impl TypedOp for Im2Col {
220 as_op!();
221
222 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
223 let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
224 let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
225 let mn = output_shape.hw_dims().iter().product::<TDim>();
226 let pof = PackedExoticFact {
227 format: dyn_clone::clone_box(self.geometry.out_format()),
228 k: self.geometry.k(),
229 mn,
230 };
231 Ok(tvec!(
232 inputs[0]
233 .datum_type
234 .fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
235 .with_exotic_fact(pof)
236 ))
237 }
238
239 fn declutter(
240 &self,
241 model: &TypedModel,
242 node: &TypedNode,
243 ) -> TractResult<Option<TypedModelPatch>> {
244 let input_fact = model.outlet_fact(node.inputs[0])?;
245 if node.inputs.len() == 2
246 && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
247 == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
248 {
249 Ok(Some(
250 TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
251 .with_context("b0 is zero"),
252 ))
253 } else {
254 Ok(None)
255 }
256 }
257}
258
/// Routine used to write the im2col patches, selected from the patch rank and whether
/// the patch is padded when the geometry is resolved.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
enum Patcher {
    /// Any rank, padded or not: fills a temporary `k x n` matrix, then writes it out.
    Generic,
    /// Rank-1 patch without padding.
    Valid1d,
    /// Rank-2 patch without padding.
    Valid2d,
    /// Rank-2 patch with padding.
    Padded2d,
}
266
267impl Patcher {
268 fn patch<'p, T: Copy + Datum + num_traits::Zero>(
269 &self,
270 geo: &'p ConcreteGeometry,
271 input: &TensorView,
272 pack: &'p mut TensorView,
273 g: usize,
274 pad_value: Option<&Tensor>,
275 ) -> TractResult<()> {
276 let ptr = unsafe { pack.as_slice_mut_unchecked::<T>().as_mut_ptr() };
280 if let Some(pf) = geo.out_format.downcast_ref::<PackedFormat>() {
281 let mut w = pf.write_with_k_outer(ptr, geo.k, geo.n);
282 self.run::<T, _>(geo, input, g, pad_value, &mut w)
283 } else if let Some(p4) = geo.out_format.downcast_ref::<PackedI8K4>() {
284 let mut w = p4.write_with_k_outer(ptr, geo.k, geo.n);
285 self.run::<T, _>(geo, input, g, pad_value, &mut w)
286 } else {
287 bail!("Im2Col: unsupported packing format {:?}", geo.out_format)
288 }
289 }
290
291 fn run<T: Copy + Datum + num_traits::Zero, W: PackingWriter<T>>(
292 &self,
293 geo: &ConcreteGeometry,
294 input: &TensorView,
295 g: usize,
296 pad_value: Option<&Tensor>,
297 writer: &mut W,
298 ) -> TractResult<()> {
299 match self {
300 Patcher::Valid1d => Self::valid_1d::<T, W>(geo, input, g, writer),
301 Patcher::Valid2d => Self::valid_2d::<T, W>(geo, input, g, writer),
302 Patcher::Padded2d => Self::padded_2d::<T, W>(
303 geo,
304 input,
305 g,
306 pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
307 writer,
308 ),
309 _ => Self::generic::<T, W>(
310 geo,
311 input,
312 g,
313 pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
314 writer,
315 ),
316 }
317 }
318
319 #[inline(never)]
320 fn generic<T: Copy + Datum, W: PackingWriter<T>>(
321 geometry: &ConcreteGeometry,
322 input: &TensorView,
323 g: usize,
324 pad_value: &Tensor,
325 writer: &mut W,
326 ) -> TractResult<()> {
327 unsafe {
328 let pad_value = *pad_value.to_scalar_unchecked();
329 let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
330 let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
331 let ptr = input.as_ptr_unchecked::<T>();
332 let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
333 for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
334 .into_iter()
335 .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
336 {
337 let mut col = col.iter_mut();
338 for ci in 0..geometry.ci_per_group {
339 let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
340 for v in geometry.pool.patch.at(spatial.slice()) {
341 *col.next().expect("geometry error in conv") =
342 v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
343 }
344 }
345 }
346 let mv = mega_matrix.as_slice_unchecked::<T>();
350 for kk in 0..geometry.k {
351 writer.write_slice(&mv[kk * geometry.n..(kk + 1) * geometry.n]);
352 }
353 Ok(())
354 }
355 }
356
357 #[inline(never)]
358 fn valid_1d<T: Copy + Datum, W: PackingWriter<T>>(
359 geometry: &ConcreteGeometry,
360 input: &TensorView,
361 g: usize,
362 writer: &mut W,
363 ) -> TractResult<()> {
364 unsafe {
365 let x_stride = *geometry.input_shape_with_n.h_stride() as isize
366 * geometry.pool.patch.spec.strides[0] as isize;
367 let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
368 let iptr = input.as_ptr_unchecked::<T>();
369 let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
370 let output_x = *geometry.pool.patch.output_shape.get_unchecked(0);
371 let contiguous_x = x_stride == 1;
376 for ci in 0..geometry.ci_per_group {
377 let iptr = iptr.offset(ci as isize * c_stride);
378 for koffset in &geometry.pool.patch.standard_layout_data_field {
379 let iptr = iptr.offset(*koffset);
380 if contiguous_x {
381 let row = std::slice::from_raw_parts(iptr, output_x);
382 writer.write_slice(row);
383 } else {
384 let mut iptr_x = iptr;
386 for _ in 0..output_x {
387 writer.write(*iptr_x);
388 iptr_x = iptr_x.offset(x_stride);
389 }
390 }
391 }
392 }
393 Ok(())
394 }
395 }
396
397 #[inline(never)]
398 fn padded_2d<T: Copy + Datum, W: PackingWriter<T>>(
399 geometry: &ConcreteGeometry,
400 input: &TensorView,
401 g: usize,
402 pad_value: &Tensor,
403 writer: &mut W,
404 ) -> TractResult<()> {
405 unsafe {
406 let pad_value = *pad_value.to_scalar_unchecked();
407 let y_stride = geometry.pool.patch.spec.strides[0] as isize;
408 let x_stride = geometry.pool.patch.spec.strides[1] as isize;
409 let shape = &geometry.input_shape_with_n;
410 let y_stride_ptr = y_stride * *shape.h_stride() as isize;
411 let x_stride_ptr = x_stride * *shape.w_stride() as isize;
412 let c_stride_ptr = *shape.c_stride() as isize;
413 let input_heigth = shape.hw_dims()[0] as isize;
414 let input_width = shape.hw_dims()[1] as isize;
415 let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
416 let iptr = input.as_ptr_unchecked::<T>();
417 let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
418 let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
419 for ci in 0..geometry.ci_per_group {
420 let iptr = iptr.offset(ci as isize * c_stride_ptr);
421 for kitem in 0..kernel_len {
422 let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
423 let dx =
424 *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
425 let valid_x_start =
426 Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
427 let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
428 .max(0)
429 .min(output_width as _);
430
431 let iptr = iptr.offset(
432 *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
433 );
434 for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
435 let y = yo as isize * y_stride + dy;
436 let iptr = iptr.offset(yo as isize * y_stride_ptr);
437 if y >= 0 && y < input_heigth {
438 Self::padded_2d_invalid_x_loop(
439 valid_x_start as usize,
440 pad_value,
441 &mut *writer,
442 );
443 Self::padded_2d_valid_x_loop(
444 valid_x_start,
445 valid_x_end,
446 x_stride_ptr,
447 iptr,
448 &mut *writer,
449 );
450 Self::padded_2d_invalid_x_loop(
451 output_width - valid_x_end as usize,
452 pad_value,
453 &mut *writer,
454 );
455 } else {
456 Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut *writer);
457 }
458 }
459 }
460 }
461 }
462 Ok(())
463 }
464
465 #[inline(never)]
466 unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
467 count: usize,
468 pad_value: T,
469 writer: &mut W,
470 ) {
471 for _ in 0..count {
472 writer.write(pad_value);
473 }
474 }
475
476 #[inline(never)]
477 unsafe fn padded_2d_valid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
478 x_min: isize,
479 x_max: isize,
480 x_stride_ptr: isize,
481 iptr: *const T,
482 writer: &mut W,
483 ) {
484 if x_stride_ptr == 1 && x_max > x_min {
488 unsafe {
489 let row = std::slice::from_raw_parts(iptr.offset(x_min), (x_max - x_min) as usize);
490 writer.write_slice(row);
491 }
492 } else {
493 for x in x_min..x_max {
494 writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
495 }
496 }
497 }
498
499 #[inline(never)]
500 fn valid_2d<T: Copy + Datum, W: PackingWriter<T>>(
501 geometry: &ConcreteGeometry,
502 input: &TensorView,
503 g: usize,
504 writer: &mut W,
505 ) -> TractResult<()> {
506 unsafe {
507 let shape = &geometry.input_shape_with_n;
508 let y_stride = geometry.pool.patch.spec.strides[0] as isize;
509 let x_stride = geometry.pool.patch.spec.strides[1] as isize;
510 let y_stride_ptr = y_stride * *shape.h_stride() as isize;
511 let x_stride_ptr = x_stride * *shape.w_stride() as isize;
512 let c_stride_ptr = *shape.c_stride() as isize;
513 let iptr = input.as_ptr_unchecked::<T>();
514 let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
515 let output_y = *geometry.pool.patch.output_shape.get_unchecked(0);
516 let output_x = *geometry.pool.patch.output_shape.get_unchecked(1);
517 let contiguous_x = x_stride_ptr == 1;
521 for ci in 0..geometry.ci_per_group {
522 let iptr = iptr.offset(ci as isize * c_stride_ptr);
523 for koffset in &geometry.pool.patch.standard_layout_data_field {
524 let iptr = iptr.offset(*koffset);
525 let mut iptr_y = iptr;
526 for _ in 0..output_y {
527 if contiguous_x {
528 let row = std::slice::from_raw_parts(iptr_y, output_x);
529 writer.write_slice(row);
530 } else {
531 let mut iptr_x = iptr_y;
533 for _ in 0..output_x {
534 writer.write(*iptr_x);
535 iptr_x = iptr_x.offset(x_stride_ptr);
536 }
537 }
538 iptr_y = iptr_y.offset(y_stride_ptr);
539 }
540 }
541 }
542 Ok(())
543 }
544 }
545}