// tract_core/ops/cnn/conv/im2col.rs

1use tract_linalg::mmm::{EagerPackedInput, MMMInputValue, MatMatMul, PackedOpaqueFact};
2use tract_linalg::pack::{PackedFormat, PackingWriter};
3
4use crate::internal::*;
5use ndarray::prelude::*;
6use num_integer::Integer;
7
8use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
9use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
10use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
13pub struct Im2Col {
14    pub pool_spec: PoolSpec,
15    pub group: usize,
16    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
17}
18
19#[derive(Debug, Clone, Hash, PartialEq, Eq)]
20struct SymbolicGeometry {
21    group: usize,
22    pool_spec: PoolSpec,
23    pool_geometry: PoolGeometry,
24    b_pack: PackedFormat,
25    k: usize,
26}
27
28#[derive(Debug, Clone, Hash, PartialEq, Eq)]
29struct ConcreteGeometry {
30    pool: ConcretePoolGeometry,
31    pub n: usize,
32    k: usize,
33    pub b_pack: PackedFormat,
34    pub ci_per_group: usize,
35    patcher: Patcher,
36    input_shape_with_n: DataShape,
37    packed_shape: TVec<usize>, // always Batch,Group
38}
39
40impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
41    pub fn b_pack(&self) -> &PackedFormat {
42        match self {
43            GeometryBound::Symbolic(s) => &s.b_pack,
44            GeometryBound::Concrete(s) => &s.b_pack,
45        }
46    }
47    pub fn k(&self) -> usize {
48        match self {
49            GeometryBound::Symbolic(s) => s.k,
50            GeometryBound::Concrete(s) => s.k,
51        }
52    }
53}
54
55impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
56    type Param = [usize];
57    fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
58        let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
59        let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
60            Patcher::Valid2d
61        } else if pool.patch.rank() == 2 {
62            Patcher::Padded2d
63        } else if !pool.patch.padded && pool.patch.rank() == 1 {
64            Patcher::Valid1d
65        } else {
66            Patcher::Generic
67        };
68        let ci_per_group = pool.input_shape.c_dim() / self.group;
69        let n = pool.output_shape.hw_dims().iter().product();
70        let input_shape_with_n = match self.pool_spec.data_format {
71            DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
72                1,
73                *pool.input_shape.c(),
74                pool.input_shape.hw_dims(),
75            )?,
76            DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
77                1,
78                *pool.input_shape.c(),
79                pool.input_shape.hw_dims(),
80            )?,
81            _ => pool.input_shape.clone(),
82        };
83        let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
84        Ok(ConcreteGeometry {
85            pool,
86            n,
87            k: self.k,
88            ci_per_group,
89            b_pack: self.b_pack.clone(),
90            patcher,
91            input_shape_with_n,
92            packed_shape,
93        })
94    }
95}
96
97impl Im2Col {
98    pub fn new(
99        pool_spec: PoolSpec,
100        group: usize,
101        k: usize,
102        input_full_shape: &ShapeFact,
103        mmm: Box<dyn MatMatMul>,
104        packing: usize,
105    ) -> TractResult<Im2Col> {
106        let b_pack = mmm.packings()[packing]
107            .1
108            .downcast_ref::<PackedFormat>()
109            .context("Im2Col expects regular packed format")?
110            .clone();
111
112        let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
113        let geometry: GeometryBound<_, _> =
114            SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k }
115                .into();
116        let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
117        Ok(Im2Col { pool_spec, group, geometry })
118    }
119
120    // packed shape is Batch,Group
121    fn packed_shape<D: DimLike>(
122        input_shape: &BaseDataShape<D, TVec<D>>,
123        group: usize,
124    ) -> TractResult<TVec<D>> {
125        let mut output_shape: TVec<D> = tvec!();
126        output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
127        output_shape.push(group.into());
128        Ok(output_shape)
129    }
130}
131
132impl Op for Im2Col {
133    fn name(&self) -> StaticName {
134        "Im2col".into()
135    }
136
137    fn info(&self) -> TractResult<Vec<String>> {
138        Ok(vec![format!("groups:{}", self.group)])
139    }
140
141    impl_op_same_as!();
142    op_as_typed_op!();
143}
144
145impl EvalOp for Im2Col {
146    fn is_stateless(&self) -> bool {
147        true
148    }
149
150    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
151        let geometry = self.geometry.to_concrete(inputs[0].shape())?;
152        unsafe {
153            let mut input = inputs.remove(0).into_tensor();
154            let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
155            let mut output = Tensor::uninitialized::<Opaque>(&geometry.packed_shape)?;
156            if !self.pool_spec.data_format.has_n() {
157                input.insert_axis(0)?;
158            }
159            let mut output_dense = output.try_as_dense_mut()?;
160            let mut output_view = output_dense.to_array_view_mut::<Opaque>()?;
161            let panel_bytes =
162                geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of();
163
164            // in the loop, we have normalized the input so that N is
165            // always here, and output so that N and G are there.
166            if !geometry.pool.output_shape.shape.contains(&0) {
167                for i in 0..*geometry.input_shape_with_n.n().unwrap_or(&1) {
168                    let input = input.view_at_prefix(&[i])?;
169                    for g in 0..self.group {
170                        let mut data = Tensor::uninitialized_aligned_dt(
171                            input.datum_type(),
172                            &[geometry.b_pack.len(geometry.k, geometry.n)],
173                            geometry.b_pack.alignment(),
174                        )?;
175                        dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
176                            &geometry.patcher,
177                            &geometry,
178                            &input,
179                            &mut data.view_mut(),
180                            g,
181                            pad_value
182                        ))?;
183                        let input: Box<dyn MMMInputValue> = Box::new(EagerPackedInput {
184                            fact: PackedOpaqueFact {
185                                format: Box::new(geometry.b_pack.clone()),
186                                k: geometry.k,
187                                mn: geometry.n.to_dim(),
188                            },
189                            packed: data.into_blob()?.into(),
190                            panel_bytes,
191                            mn: geometry.n,
192                        });
193                        output_view[[i, g]] = input.into();
194                    }
195                }
196            }
197            Ok(tvec!(output.into_tvalue()))
198        }
199    }
200}
201
202impl TypedOp for Im2Col {
203    as_op!();
204
205    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
206        let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
207        let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
208        let mn = output_shape.hw_dims().iter().product::<TDim>();
209        let pof = PackedOpaqueFact {
210            format: Box::new(self.geometry.b_pack().clone()),
211            k: self.geometry.k(),
212            mn,
213        };
214        Ok(tvec!(
215            Opaque::fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
216                .with_opaque_fact(pof)
217        ))
218    }
219
220    fn declutter(
221        &self,
222        model: &TypedModel,
223        node: &TypedNode,
224    ) -> TractResult<Option<TypedModelPatch>> {
225        let input_fact = model.outlet_fact(node.inputs[0])?;
226        if node.inputs.len() == 2
227            && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
228                == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
229        {
230            Ok(Some(
231                TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
232                    .with_context("b0 is zero"),
233            ))
234        } else {
235            Ok(None)
236        }
237    }
238}
239
/// Strategy used to fill the packed im2col buffer, chosen once per concrete geometry.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Patcher {
    /// Fallback for every other case (padded 1D, rank > 2, ...): builds the full matrix,
    /// then packs it.
    Generic,
    /// Rank-1 patch without padding, written straight into the packed buffer.
    Valid1d,
    /// Rank-2 patch without padding, written straight into the packed buffer.
    Valid2d,
    /// Rank-2 patch with padding, written straight into the packed buffer.
    Padded2d,
}
247
248impl Patcher {
249    fn patch<'p, T: Copy + Datum + num_traits::Zero>(
250        &self,
251        geo: &'p ConcreteGeometry,
252        input: &TensorView,
253        pack: &'p mut TensorView,
254        g: usize,
255        pad_value: Option<&Tensor>,
256    ) -> TractResult<()> {
257        match self {
258            Patcher::Valid1d => Self::valid_1d::<T>(geo, input, pack, g),
259            Patcher::Valid2d => Self::valid_2d::<T>(geo, input, pack, g),
260            Patcher::Padded2d => Self::padded_2d::<T>(
261                geo,
262                input,
263                pack,
264                g,
265                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
266            ),
267            _ => Self::generic::<T>(
268                geo,
269                input,
270                pack,
271                g,
272                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
273            ),
274        }
275    }
276
277    #[inline(never)]
278    fn generic<'p, T: Copy + Datum>(
279        geometry: &'p ConcreteGeometry,
280        input: &TensorView,
281        pack: &'p mut TensorView,
282        g: usize,
283        pad_value: &Tensor,
284    ) -> TractResult<()> {
285        unsafe {
286            let pad_value = *pad_value.to_scalar_unchecked();
287            let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
288            let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
289            let ptr = input.as_ptr_unchecked::<T>();
290            let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
291            for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
292                .into_iter()
293                .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
294            {
295                let mut col = col.iter_mut();
296                for ci in 0..geometry.ci_per_group {
297                    let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
298                    for v in geometry.pool.patch.at(spatial.slice()) {
299                        *col.next().expect("geometry error in conv") =
300                            v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
301                    }
302                }
303            }
304            geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1);
305            Ok(())
306        }
307    }
308
309    #[inline(never)]
310    fn valid_1d<'p, T: Copy + Datum>(
311        geometry: &'p ConcreteGeometry,
312        input: &TensorView,
313        pack: &'p mut TensorView,
314        g: usize,
315    ) -> TractResult<()> {
316        unsafe {
317            let x_stride = *geometry.input_shape_with_n.h_stride() as isize
318                * geometry.pool.patch.spec.strides[0] as isize;
319            let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
320            let pack = pack.as_slice_mut_unchecked::<T>();
321            let mut writer =
322                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
323            let iptr = input.as_ptr_unchecked::<T>();
324            let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
325            for ci in 0..geometry.ci_per_group {
326                let iptr = iptr.offset(ci as isize * c_stride);
327                for koffset in &geometry.pool.patch.standard_layout_data_field {
328                    let iptr = iptr.offset(*koffset);
329                    for x in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
330                        writer.write(*iptr.offset(x as isize * x_stride));
331                    }
332                }
333            }
334            Ok(())
335        }
336    }
337
338    #[inline(never)]
339    fn padded_2d<'p, T: Copy + Datum>(
340        geometry: &'p ConcreteGeometry,
341        input: &TensorView,
342        pack: &'p mut TensorView,
343        g: usize,
344        pad_value: &Tensor,
345    ) -> TractResult<()> {
346        unsafe {
347            let pad_value = *pad_value.to_scalar_unchecked();
348            let pack = pack.as_slice_mut_unchecked::<T>();
349            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
350            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
351            let shape = &geometry.input_shape_with_n;
352            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
353            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
354            let c_stride_ptr = *shape.c_stride() as isize;
355            let input_heigth = shape.hw_dims()[0] as isize;
356            let input_width = shape.hw_dims()[1] as isize;
357            let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
358            let mut writer =
359                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
360            let iptr = input.as_ptr_unchecked::<T>();
361            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
362            let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
363            for ci in 0..geometry.ci_per_group {
364                let iptr = iptr.offset(ci as isize * c_stride_ptr);
365                for kitem in 0..kernel_len {
366                    let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
367                    let dx =
368                        *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
369                    let valid_x_start =
370                        Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
371                    let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
372                        .max(0)
373                        .min(output_width as _);
374
375                    let iptr = iptr.offset(
376                        *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
377                    );
378                    for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
379                        let y = yo as isize * y_stride + dy;
380                        let iptr = iptr.offset(yo as isize * y_stride_ptr);
381                        if y >= 0 && y < input_heigth {
382                            Self::padded_2d_invalid_x_loop(
383                                valid_x_start as usize,
384                                pad_value,
385                                &mut writer,
386                            );
387                            Self::padded_2d_valid_x_loop(
388                                valid_x_start,
389                                valid_x_end,
390                                x_stride_ptr,
391                                iptr,
392                                &mut writer,
393                            );
394                            Self::padded_2d_invalid_x_loop(
395                                output_width - valid_x_end as usize,
396                                pad_value,
397                                &mut writer,
398                            );
399                        } else {
400                            Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer);
401                        }
402                    }
403                }
404            }
405        }
406        Ok(())
407    }
408
409    #[inline(never)]
410    unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum>(
411        count: usize,
412        pad_value: T,
413        writer: &mut tract_linalg::pack::KOutWriter<T>,
414    ) {
415        for _ in 0..count {
416            writer.write(pad_value);
417        }
418    }
419
420    #[inline(never)]
421    unsafe fn padded_2d_valid_x_loop<T: Copy + Datum>(
422        x_min: isize,
423        x_max: isize,
424        x_stride_ptr: isize,
425        iptr: *const T,
426        writer: &mut tract_linalg::pack::KOutWriter<T>,
427    ) {
428        for x in x_min..x_max {
429            writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
430        }
431    }
432
433    #[inline(never)]
434    fn valid_2d<'p, T: Copy + Datum>(
435        geometry: &'p ConcreteGeometry,
436        input: &TensorView,
437        pack: &'p mut TensorView,
438        g: usize,
439    ) -> TractResult<()> {
440        unsafe {
441            let pack = pack.as_slice_mut_unchecked::<T>();
442            let shape = &geometry.input_shape_with_n;
443            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
444            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
445            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
446            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
447            let c_stride_ptr = *shape.c_stride() as isize;
448            let mut writer =
449                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
450            let iptr = input.as_ptr_unchecked::<T>();
451            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
452            for ci in 0..geometry.ci_per_group {
453                let iptr = iptr.offset(ci as isize * c_stride_ptr);
454                for koffset in &geometry.pool.patch.standard_layout_data_field {
455                    let iptr = iptr.offset(*koffset);
456                    for y in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
457                        let iptr = iptr.offset(y as isize * y_stride_ptr);
458                        for x in 0..*geometry.pool.patch.output_shape.get_unchecked(1) {
459                            writer.write(*iptr.offset(x as isize * x_stride_ptr));
460                        }
461                    }
462                }
463            }
464            Ok(())
465        }
466    }
467}