// tract_core/ops/cnn/conv/im2col.rs

1use tract_linalg::mmm::{
2    EagerPackedInput, MMMInputValue, MatMatMul, PackedExoticFact, PackedMatrixStorage,
3};
4use tract_linalg::pack::{PackedFormat, PackingWriter};
5
6use crate::internal::*;
7use ndarray::prelude::*;
8use num_integer::Integer;
9
10use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
11use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
12use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
13
/// Im2col operator: rearranges convolution input patches into packed B-operand matrices
/// for a `MatMatMul` kernel, producing one packed matrix per (batch, group) pair.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2Col {
    /// Pooling spec describing the patch geometry and the input data format.
    pub pool_spec: PoolSpec,
    /// Number of convolution groups; input channels are divided among them (`c / group` each).
    pub group: usize,
    // Symbolic geometry, or its concrete resolution when the input shape was already
    // concrete at construction time (see `Im2Col::new`).
    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
}
20
/// Geometry known before the input shape is concrete; turned into a `ConcreteGeometry`
/// by `ResolveTo::resolve` once the full input shape is available.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SymbolicGeometry {
    group: usize,
    pool_spec: PoolSpec,
    /// Pooling geometry computed from `pool_spec` and the (possibly symbolic) input shape.
    pool_geometry: PoolGeometry,
    /// Packing format of the B operand, taken from the matmul kernel's packings.
    b_pack: PackedFormat,
    /// Reduction (`k`) dimension of each packed matrix.
    k: usize,
}
29
/// Fully resolved geometry for one concrete input shape: everything the patchers need.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConcreteGeometry {
    /// Concrete pooling geometry (patch description, input and output shapes).
    pool: ConcretePoolGeometry,
    /// Number of output spatial positions (product of the output spatial dims); this is the
    /// `mn` dimension of each packed matrix.
    pub n: usize,
    /// Reduction (`k`) dimension of each packed matrix.
    k: usize,
    pub b_pack: PackedFormat,
    /// Input channels handled by each group.
    pub ci_per_group: usize,
    /// Copy routine selected from the patch rank and padding.
    patcher: Patcher,
    /// Input shape with an explicit N axis (size 1 added for HWC / CHW formats).
    input_shape_with_n: DataShape,
    packed_shape: TVec<usize>, // always Batch,Group
}
41
42impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
43    pub fn b_pack(&self) -> &PackedFormat {
44        match self {
45            GeometryBound::Symbolic(s) => &s.b_pack,
46            GeometryBound::Concrete(s) => &s.b_pack,
47        }
48    }
49    pub fn k(&self) -> usize {
50        match self {
51            GeometryBound::Symbolic(s) => s.k,
52            GeometryBound::Concrete(s) => s.k,
53        }
54    }
55}
56
57impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
58    type Param = [usize];
59    fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
60        let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
61        let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
62            Patcher::Valid2d
63        } else if pool.patch.rank() == 2 {
64            Patcher::Padded2d
65        } else if !pool.patch.padded && pool.patch.rank() == 1 {
66            Patcher::Valid1d
67        } else {
68            Patcher::Generic
69        };
70        let ci_per_group = pool.input_shape.c_dim() / self.group;
71        let n = pool.output_shape.hw_dims().iter().product();
72        let input_shape_with_n = match self.pool_spec.data_format {
73            DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
74                1,
75                *pool.input_shape.c(),
76                pool.input_shape.hw_dims(),
77            )?,
78            DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
79                1,
80                *pool.input_shape.c(),
81                pool.input_shape.hw_dims(),
82            )?,
83            _ => pool.input_shape.clone(),
84        };
85        let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
86        Ok(ConcreteGeometry {
87            pool,
88            n,
89            k: self.k,
90            ci_per_group,
91            b_pack: self.b_pack.clone(),
92            patcher,
93            input_shape_with_n,
94            packed_shape,
95        })
96    }
97}
98
99impl Im2Col {
100    pub fn new(
101        pool_spec: PoolSpec,
102        group: usize,
103        k: usize,
104        input_full_shape: &ShapeFact,
105        mmm: Box<dyn MatMatMul>,
106        packing: usize,
107    ) -> TractResult<Im2Col> {
108        let b_pack = mmm.packings()[packing]
109            .1
110            .downcast_ref::<PackedFormat>()
111            .context("Im2Col expects regular packed format")?
112            .clone();
113
114        let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
115        let geometry: GeometryBound<_, _> =
116            SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k }
117                .into();
118        let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
119        Ok(Im2Col { pool_spec, group, geometry })
120    }
121
122    // packed shape is Batch,Group
123    fn packed_shape<D: DimLike>(
124        input_shape: &BaseDataShape<D, TVec<D>>,
125        group: usize,
126    ) -> TractResult<TVec<D>> {
127        let mut output_shape: TVec<D> = tvec!();
128        output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
129        output_shape.push(group.into());
130        Ok(output_shape)
131    }
132}
133
134impl Op for Im2Col {
135    fn name(&self) -> StaticName {
136        "Im2col".into()
137    }
138
139    fn info(&self) -> TractResult<Vec<String>> {
140        Ok(vec![format!("groups:{}", self.group)])
141    }
142
143    op_as_typed_op!();
144}
145
146impl EvalOp for Im2Col {
147    fn is_stateless(&self) -> bool {
148        true
149    }
150
151    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
152        let geometry = self.geometry.to_concrete(inputs[0].shape())?;
153        unsafe {
154            let mut input = inputs.remove(0).into_tensor();
155            let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
156            if !self.pool_spec.data_format.has_n() {
157                input.insert_axis(0)?;
158            }
159            let panel_bytes =
160                geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of();
161
162            let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1);
163            let n_groups = self.group;
164            let mut values: Vec<Box<dyn MMMInputValue>> = Vec::with_capacity(n_batches * n_groups);
165
166            for i in 0..n_batches {
167                let input = input.view_at_prefix(&[i])?;
168                for g in 0..n_groups {
169                    let n =
170                        if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n };
171                    let mut data = Tensor::uninitialized_aligned_dt(
172                        input.datum_type(),
173                        &[geometry.b_pack.len(geometry.k, n)],
174                        geometry.b_pack.alignment(),
175                    )?;
176                    if n > 0 {
177                        dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
178                            &geometry.patcher,
179                            &geometry,
180                            &input,
181                            &mut data.view_mut(),
182                            g,
183                            pad_value
184                        ))?;
185                    }
186                    values.push(Box::new(EagerPackedInput {
187                        fact: PackedExoticFact {
188                            format: Box::new(geometry.b_pack.clone()),
189                            k: geometry.k,
190                            mn: n.to_dim(),
191                        },
192                        packed: data.into_blob()?.into(),
193                        panel_bytes: if n > 0 { panel_bytes } else { 0 },
194                        mn: n,
195                    }));
196                }
197            }
198
199            let output = PackedMatrixStorage::new_batched(&geometry.packed_shape, values)
200                .into_tensor(input.datum_type());
201            Ok(tvec!(output.into_tvalue()))
202        }
203    }
204}
205
206impl TypedOp for Im2Col {
207    as_op!();
208
209    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210        let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
211        let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
212        let mn = output_shape.hw_dims().iter().product::<TDim>();
213        let pof = PackedExoticFact {
214            format: Box::new(self.geometry.b_pack().clone()),
215            k: self.geometry.k(),
216            mn,
217        };
218        Ok(tvec!(
219            inputs[0]
220                .datum_type
221                .fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
222                .with_exotic_fact(pof)
223        ))
224    }
225
226    fn declutter(
227        &self,
228        model: &TypedModel,
229        node: &TypedNode,
230    ) -> TractResult<Option<TypedModelPatch>> {
231        let input_fact = model.outlet_fact(node.inputs[0])?;
232        if node.inputs.len() == 2
233            && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
234                == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
235        {
236            Ok(Some(
237                TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
238                    .with_context("b0 is zero"),
239            ))
240        } else {
241            Ok(None)
242        }
243    }
244}
245
/// Im2col copy strategies, selected in `resolve` from the patch rank and padding.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
enum Patcher {
    /// Any rank/padding: materialises a `[k, n]` matrix, then packs it.
    Generic,
    /// 1-D patch without padding.
    Valid1d,
    /// 2-D patch without padding.
    Valid2d,
    /// 2-D patch with padding.
    Padded2d,
}
253
254impl Patcher {
255    fn patch<'p, T: Copy + Datum + num_traits::Zero>(
256        &self,
257        geo: &'p ConcreteGeometry,
258        input: &TensorView,
259        pack: &'p mut TensorView,
260        g: usize,
261        pad_value: Option<&Tensor>,
262    ) -> TractResult<()> {
263        match self {
264            Patcher::Valid1d => Self::valid_1d::<T>(geo, input, pack, g),
265            Patcher::Valid2d => Self::valid_2d::<T>(geo, input, pack, g),
266            Patcher::Padded2d => Self::padded_2d::<T>(
267                geo,
268                input,
269                pack,
270                g,
271                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
272            ),
273            _ => Self::generic::<T>(
274                geo,
275                input,
276                pack,
277                g,
278                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
279            ),
280        }
281    }
282
283    #[inline(never)]
284    fn generic<'p, T: Copy + Datum>(
285        geometry: &'p ConcreteGeometry,
286        input: &TensorView,
287        pack: &'p mut TensorView,
288        g: usize,
289        pad_value: &Tensor,
290    ) -> TractResult<()> {
291        unsafe {
292            let pad_value = *pad_value.to_scalar_unchecked();
293            let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
294            let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
295            let ptr = input.as_ptr_unchecked::<T>();
296            let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
297            for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
298                .into_iter()
299                .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
300            {
301                let mut col = col.iter_mut();
302                for ci in 0..geometry.ci_per_group {
303                    let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
304                    for v in geometry.pool.patch.at(spatial.slice()) {
305                        *col.next().expect("geometry error in conv") =
306                            v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
307                    }
308                }
309            }
310            geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1);
311            Ok(())
312        }
313    }
314
315    #[inline(never)]
316    fn valid_1d<'p, T: Copy + Datum>(
317        geometry: &'p ConcreteGeometry,
318        input: &TensorView,
319        pack: &'p mut TensorView,
320        g: usize,
321    ) -> TractResult<()> {
322        unsafe {
323            let x_stride = *geometry.input_shape_with_n.h_stride() as isize
324                * geometry.pool.patch.spec.strides[0] as isize;
325            let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
326            let pack = pack.as_slice_mut_unchecked::<T>();
327            let mut writer =
328                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
329            let iptr = input.as_ptr_unchecked::<T>();
330            let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
331            for ci in 0..geometry.ci_per_group {
332                let iptr = iptr.offset(ci as isize * c_stride);
333                for koffset in &geometry.pool.patch.standard_layout_data_field {
334                    let iptr = iptr.offset(*koffset);
335                    for x in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
336                        writer.write(*iptr.offset(x as isize * x_stride));
337                    }
338                }
339            }
340            Ok(())
341        }
342    }
343
344    #[inline(never)]
345    fn padded_2d<'p, T: Copy + Datum>(
346        geometry: &'p ConcreteGeometry,
347        input: &TensorView,
348        pack: &'p mut TensorView,
349        g: usize,
350        pad_value: &Tensor,
351    ) -> TractResult<()> {
352        unsafe {
353            let pad_value = *pad_value.to_scalar_unchecked();
354            let pack = pack.as_slice_mut_unchecked::<T>();
355            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
356            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
357            let shape = &geometry.input_shape_with_n;
358            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
359            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
360            let c_stride_ptr = *shape.c_stride() as isize;
361            let input_heigth = shape.hw_dims()[0] as isize;
362            let input_width = shape.hw_dims()[1] as isize;
363            let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
364            let mut writer =
365                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
366            let iptr = input.as_ptr_unchecked::<T>();
367            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
368            let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
369            for ci in 0..geometry.ci_per_group {
370                let iptr = iptr.offset(ci as isize * c_stride_ptr);
371                for kitem in 0..kernel_len {
372                    let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
373                    let dx =
374                        *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
375                    let valid_x_start =
376                        Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
377                    let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
378                        .max(0)
379                        .min(output_width as _);
380
381                    let iptr = iptr.offset(
382                        *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
383                    );
384                    for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
385                        let y = yo as isize * y_stride + dy;
386                        let iptr = iptr.offset(yo as isize * y_stride_ptr);
387                        if y >= 0 && y < input_heigth {
388                            Self::padded_2d_invalid_x_loop(
389                                valid_x_start as usize,
390                                pad_value,
391                                &mut writer,
392                            );
393                            Self::padded_2d_valid_x_loop(
394                                valid_x_start,
395                                valid_x_end,
396                                x_stride_ptr,
397                                iptr,
398                                &mut writer,
399                            );
400                            Self::padded_2d_invalid_x_loop(
401                                output_width - valid_x_end as usize,
402                                pad_value,
403                                &mut writer,
404                            );
405                        } else {
406                            Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer);
407                        }
408                    }
409                }
410            }
411        }
412        Ok(())
413    }
414
415    #[inline(never)]
416    unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum>(
417        count: usize,
418        pad_value: T,
419        writer: &mut tract_linalg::pack::KOutWriter<T>,
420    ) {
421        for _ in 0..count {
422            writer.write(pad_value);
423        }
424    }
425
426    #[inline(never)]
427    unsafe fn padded_2d_valid_x_loop<T: Copy + Datum>(
428        x_min: isize,
429        x_max: isize,
430        x_stride_ptr: isize,
431        iptr: *const T,
432        writer: &mut tract_linalg::pack::KOutWriter<T>,
433    ) {
434        for x in x_min..x_max {
435            writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
436        }
437    }
438
439    #[inline(never)]
440    fn valid_2d<'p, T: Copy + Datum>(
441        geometry: &'p ConcreteGeometry,
442        input: &TensorView,
443        pack: &'p mut TensorView,
444        g: usize,
445    ) -> TractResult<()> {
446        unsafe {
447            let pack = pack.as_slice_mut_unchecked::<T>();
448            let shape = &geometry.input_shape_with_n;
449            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
450            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
451            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
452            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
453            let c_stride_ptr = *shape.c_stride() as isize;
454            let mut writer =
455                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
456            let iptr = input.as_ptr_unchecked::<T>();
457            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
458            for ci in 0..geometry.ci_per_group {
459                let iptr = iptr.offset(ci as isize * c_stride_ptr);
460                for koffset in &geometry.pool.patch.standard_layout_data_field {
461                    let iptr = iptr.offset(*koffset);
462                    for y in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
463                        let iptr = iptr.offset(y as isize * y_stride_ptr);
464                        for x in 0..*geometry.pool.patch.output_shape.get_unchecked(1) {
465                            writer.write(*iptr.offset(x as isize * x_stride_ptr));
466                        }
467                    }
468                }
469            }
470            Ok(())
471        }
472    }
473}