Skip to main content

tract_core/ops/cnn/conv/
im2col.rs

1use tract_linalg::mmm::{
2    EagerPackedInput, MMMInputFormat, MMMInputValue, MatMatMul, PackedExoticFact,
3    PackedMatrixStorage,
4};
5use tract_linalg::pack::{PackedFormat, PackedI8K4, PackingWriter};
6
7use crate::internal::*;
8use ndarray::prelude::*;
9use num_integer::Integer;
10
11use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
12use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
13use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
14
/// Im2col operator: unfolds convolution input patches into packed matrix
/// operands laid out in the activation packing format of a `MatMatMul` kernel.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2Col {
    /// Window specification of the convolution; its `data_format` decides the input layout.
    pub pool_spec: PoolSpec,
    /// Number of convolution groups; one packed operand is produced per group and batch item.
    pub group: usize,
    /// Symbolic geometry, or its concrete resolution when the input shape is known.
    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
}
21
/// Geometry computed when the op is built, from a possibly symbolic input shape.
/// It is turned into a `ConcreteGeometry` by `resolve` once the input shape is concrete.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SymbolicGeometry {
    /// Number of convolution groups.
    group: usize,
    pool_spec: PoolSpec,
    /// Window geometry; made concrete in `resolve`.
    pool_geometry: PoolGeometry,
    // The kernel's activation packing: PackedFormat (K-major) or PackedI8K4 (K=4-inner).
    out_format: Box<dyn MMMInputFormat>,
    /// Reduction (K) length of each packed operand.
    k: usize,
}
31
/// Geometry resolved against a concrete input shape: everything `eval` and the
/// patchers need to fill the packed buffers.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConcreteGeometry {
    /// Concrete window geometry (input/output shapes and patch layout).
    pool: ConcretePoolGeometry,
    /// Number of output spatial positions (product of the output spatial dims):
    /// the MN size of each packed operand.
    pub n: usize,
    /// Reduction (K) length of each packed operand.
    k: usize,
    /// Packing format the patchers write into.
    pub out_format: Box<dyn MMMInputFormat>,
    /// Input channels handled by each group (`C / group`).
    pub ci_per_group: usize,
    /// Specialized patcher chosen from the patch rank and padding.
    patcher: Patcher,
    /// Input shape with an explicit N axis (N = 1 is inserted for HWC / CHW inputs).
    input_shape_with_n: DataShape,
    packed_shape: TVec<usize>, // always Batch,Group
}
43
44impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
45    pub fn out_format(&self) -> &dyn MMMInputFormat {
46        match self {
47            GeometryBound::Symbolic(s) => &*s.out_format,
48            GeometryBound::Concrete(s) => &*s.out_format,
49        }
50    }
51    pub fn k(&self) -> usize {
52        match self {
53            GeometryBound::Symbolic(s) => s.k,
54            GeometryBound::Concrete(s) => s.k,
55        }
56    }
57}
58
59impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
60    type Param = [usize];
61    fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
62        let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
63        let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
64            Patcher::Valid2d
65        } else if pool.patch.rank() == 2 {
66            Patcher::Padded2d
67        } else if !pool.patch.padded && pool.patch.rank() == 1 {
68            Patcher::Valid1d
69        } else {
70            Patcher::Generic
71        };
72        let ci_per_group = pool.input_shape.c_dim() / self.group;
73        let n = pool.output_shape.hw_dims().iter().product();
74        let input_shape_with_n = match self.pool_spec.data_format {
75            DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
76                1,
77                *pool.input_shape.c(),
78                pool.input_shape.hw_dims(),
79            )?,
80            DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
81                1,
82                *pool.input_shape.c(),
83                pool.input_shape.hw_dims(),
84            )?,
85            _ => pool.input_shape.clone(),
86        };
87        let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
88        Ok(ConcreteGeometry {
89            pool,
90            n,
91            k: self.k,
92            ci_per_group,
93            out_format: self.out_format.clone(),
94            patcher,
95            input_shape_with_n,
96            packed_shape,
97        })
98    }
99}
100
101impl Im2Col {
102    pub fn new(
103        pool_spec: PoolSpec,
104        group: usize,
105        k: usize,
106        input_full_shape: &ShapeFact,
107        mmm: Box<dyn MatMatMul>,
108        packing: usize,
109    ) -> TractResult<Im2Col> {
110        let out_format = dyn_clone::clone_box(&*mmm.packings()[packing].1);
111        let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
112        let geometry: GeometryBound<_, _> =
113            SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, out_format, k }
114                .into();
115        let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
116        Ok(Im2Col { pool_spec, group, geometry })
117    }
118
119    // packed shape is Batch,Group
120    fn packed_shape<D: DimLike>(
121        input_shape: &BaseDataShape<D, TVec<D>>,
122        group: usize,
123    ) -> TractResult<TVec<D>> {
124        let mut output_shape: TVec<D> = tvec!();
125        output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
126        output_shape.push(group.into());
127        Ok(output_shape)
128    }
129}
130
131impl Op for Im2Col {
132    fn name(&self) -> StaticName {
133        "Im2col".into()
134    }
135
136    fn info(&self) -> TractResult<Vec<String>> {
137        Ok(vec![format!("groups:{}", self.group)])
138    }
139
140    op_as_typed_op!();
141}
142
143impl EvalOp for Im2Col {
144    fn is_stateless(&self) -> bool {
145        true
146    }
147
148    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
149        let geometry = self.geometry.to_concrete(inputs[0].shape())?;
150        unsafe {
151            let mut input = inputs.remove(0).into_tensor();
152            let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
153            if !self.pool_spec.data_format.has_n() {
154                input.insert_axis(0)?;
155            }
156            let dt = input.datum_type();
157            let r = geometry.out_format.r();
158            // Buffer geometry. zero_init for PackedI8K4: the K=4-inner writer skips
159            // the K-padding lanes (k..k_aligned), which SMOPA accumulates — they must
160            // be 0. PackedFormat has no K padding; its mn-padding maps to discarded
161            // output rows, so uninitialized is fine (matches prior behaviour).
162            let (single_panel_len, buf_align, zero_init) =
163                if let Some(pf) = geometry.out_format.downcast_ref::<PackedFormat>() {
164                    (pf.single_panel_len(geometry.k), pf.alignment(), false)
165                } else if let Some(p4) = geometry.out_format.downcast_ref::<PackedI8K4>() {
166                    (p4.single_panel_len(geometry.k), p4.alignment(), true)
167                } else {
168                    bail!("Im2Col: unsupported packing format {:?}", geometry.out_format)
169                };
170            let panel_bytes = single_panel_len * dt.size_of();
171
172            let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1);
173            let n_groups = self.group;
174            let mut values: Vec<Box<dyn MMMInputValue>> = Vec::with_capacity(n_batches * n_groups);
175
176            for i in 0..n_batches {
177                let input = input.view_at_prefix(&[i])?;
178                for g in 0..n_groups {
179                    let n =
180                        if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n };
181                    let mut data = Tensor::uninitialized_aligned_dt(
182                        dt,
183                        &[n.divceil(r) * single_panel_len],
184                        buf_align,
185                    )?;
186                    if zero_init {
187                        data.as_bytes_mut().fill(0);
188                    }
189                    if n > 0 {
190                        dispatch_copy_by_size!(Patcher::patch(dt)(
191                            &geometry.patcher,
192                            &geometry,
193                            &input,
194                            &mut data.view_mut(),
195                            g,
196                            pad_value
197                        ))?;
198                    }
199                    values.push(Box::new(EagerPackedInput {
200                        fact: PackedExoticFact {
201                            format: geometry.out_format.clone(),
202                            k: geometry.k,
203                            mn: n.to_dim(),
204                        },
205                        packed: data.into_blob()?.into(),
206                        panel_bytes: if n > 0 { panel_bytes } else { 0 },
207                        mn: n,
208                    }));
209                }
210            }
211
212            let output = PackedMatrixStorage::new_batched(&geometry.packed_shape, values)
213                .into_tensor(input.datum_type());
214            Ok(tvec!(output.into_tvalue()))
215        }
216    }
217}
218
219impl TypedOp for Im2Col {
220    as_op!();
221
222    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
223        let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
224        let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
225        let mn = output_shape.hw_dims().iter().product::<TDim>();
226        let pof = PackedExoticFact {
227            format: dyn_clone::clone_box(self.geometry.out_format()),
228            k: self.geometry.k(),
229            mn,
230        };
231        Ok(tvec!(
232            inputs[0]
233                .datum_type
234                .fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
235                .with_exotic_fact(pof)
236        ))
237    }
238
239    fn declutter(
240        &self,
241        model: &TypedModel,
242        node: &TypedNode,
243    ) -> TractResult<Option<TypedModelPatch>> {
244        let input_fact = model.outlet_fact(node.inputs[0])?;
245        if node.inputs.len() == 2
246            && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
247                == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
248        {
249            Ok(Some(
250                TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
251                    .with_context("b0 is zero"),
252            ))
253        } else {
254            Ok(None)
255        }
256    }
257}
258
/// Strategy used to fill the packed buffer, selected in `resolve` from the patch
/// rank and whether the geometry is padded.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
enum Patcher {
    /// Fallback for any rank/padding combination not covered below.
    Generic,
    /// Rank-1 patch without padding.
    Valid1d,
    /// Rank-2 patch without padding.
    Valid2d,
    /// Rank-2 patch with padding.
    Padded2d,
}
266
267impl Patcher {
268    fn patch<'p, T: Copy + Datum + num_traits::Zero>(
269        &self,
270        geo: &'p ConcreteGeometry,
271        input: &TensorView,
272        pack: &'p mut TensorView,
273        g: usize,
274        pad_value: Option<&Tensor>,
275    ) -> TractResult<()> {
276        // Pick the packing writer for the kernel's output format, then run the
277        // (writer-generic) patcher. PackedFormat keeps the K-major fast path;
278        // PackedI8K4 writes the SMOPA K=4-inner layout in the same single pass.
279        let ptr = unsafe { pack.as_slice_mut_unchecked::<T>().as_mut_ptr() };
280        if let Some(pf) = geo.out_format.downcast_ref::<PackedFormat>() {
281            let mut w = pf.write_with_k_outer(ptr, geo.k, geo.n);
282            self.run::<T, _>(geo, input, g, pad_value, &mut w)
283        } else if let Some(p4) = geo.out_format.downcast_ref::<PackedI8K4>() {
284            let mut w = p4.write_with_k_outer(ptr, geo.k, geo.n);
285            self.run::<T, _>(geo, input, g, pad_value, &mut w)
286        } else {
287            bail!("Im2Col: unsupported packing format {:?}", geo.out_format)
288        }
289    }
290
291    fn run<T: Copy + Datum + num_traits::Zero, W: PackingWriter<T>>(
292        &self,
293        geo: &ConcreteGeometry,
294        input: &TensorView,
295        g: usize,
296        pad_value: Option<&Tensor>,
297        writer: &mut W,
298    ) -> TractResult<()> {
299        match self {
300            Patcher::Valid1d => Self::valid_1d::<T, W>(geo, input, g, writer),
301            Patcher::Valid2d => Self::valid_2d::<T, W>(geo, input, g, writer),
302            Patcher::Padded2d => Self::padded_2d::<T, W>(
303                geo,
304                input,
305                g,
306                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
307                writer,
308            ),
309            _ => Self::generic::<T, W>(
310                geo,
311                input,
312                g,
313                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
314                writer,
315            ),
316        }
317    }
318
319    #[inline(never)]
320    fn generic<T: Copy + Datum, W: PackingWriter<T>>(
321        geometry: &ConcreteGeometry,
322        input: &TensorView,
323        g: usize,
324        pad_value: &Tensor,
325        writer: &mut W,
326    ) -> TractResult<()> {
327        unsafe {
328            let pad_value = *pad_value.to_scalar_unchecked();
329            let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
330            let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
331            let ptr = input.as_ptr_unchecked::<T>();
332            let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
333            for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
334                .into_iter()
335                .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
336            {
337                let mut col = col.iter_mut();
338                for ci in 0..geometry.ci_per_group {
339                    let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
340                    for v in geometry.pool.patch.at(spatial.slice()) {
341                        *col.next().expect("geometry error in conv") =
342                            v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
343                    }
344                }
345            }
346            // mega_matrix is [k, n] (k-major); feed K-outer to the writer, which
347            // lays out the kernel's packing (K-major for PackedFormat, K=4-inner
348            // for PackedI8K4) — byte-identical to PackedFormat::pack for the former.
349            let mv = mega_matrix.as_slice_unchecked::<T>();
350            for kk in 0..geometry.k {
351                writer.write_slice(&mv[kk * geometry.n..(kk + 1) * geometry.n]);
352            }
353            Ok(())
354        }
355    }
356
357    #[inline(never)]
358    fn valid_1d<T: Copy + Datum, W: PackingWriter<T>>(
359        geometry: &ConcreteGeometry,
360        input: &TensorView,
361        g: usize,
362        writer: &mut W,
363    ) -> TractResult<()> {
364        unsafe {
365            let x_stride = *geometry.input_shape_with_n.h_stride() as isize
366                * geometry.pool.patch.spec.strides[0] as isize;
367            let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
368            let iptr = input.as_ptr_unchecked::<T>();
369            let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
370            let output_x = *geometry.pool.patch.output_shape.get_unchecked(0);
371            // Fast path: stride-1 contiguous read along x. Replaces the
372            // per-element pointer-arithmetic loop with a single write_slice
373            // (memcpy when the slice fits in the current panel).
374            // Byte-identical to the slow path (write_slice's contract).
375            let contiguous_x = x_stride == 1;
376            for ci in 0..geometry.ci_per_group {
377                let iptr = iptr.offset(ci as isize * c_stride);
378                for koffset in &geometry.pool.patch.standard_layout_data_field {
379                    let iptr = iptr.offset(*koffset);
380                    if contiguous_x {
381                        let row = std::slice::from_raw_parts(iptr, output_x);
382                        writer.write_slice(row);
383                    } else {
384                        // Hoist multiplication out of inner loop.
385                        let mut iptr_x = iptr;
386                        for _ in 0..output_x {
387                            writer.write(*iptr_x);
388                            iptr_x = iptr_x.offset(x_stride);
389                        }
390                    }
391                }
392            }
393            Ok(())
394        }
395    }
396
397    #[inline(never)]
398    fn padded_2d<T: Copy + Datum, W: PackingWriter<T>>(
399        geometry: &ConcreteGeometry,
400        input: &TensorView,
401        g: usize,
402        pad_value: &Tensor,
403        writer: &mut W,
404    ) -> TractResult<()> {
405        unsafe {
406            let pad_value = *pad_value.to_scalar_unchecked();
407            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
408            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
409            let shape = &geometry.input_shape_with_n;
410            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
411            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
412            let c_stride_ptr = *shape.c_stride() as isize;
413            let input_heigth = shape.hw_dims()[0] as isize;
414            let input_width = shape.hw_dims()[1] as isize;
415            let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
416            let iptr = input.as_ptr_unchecked::<T>();
417            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
418            let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
419            for ci in 0..geometry.ci_per_group {
420                let iptr = iptr.offset(ci as isize * c_stride_ptr);
421                for kitem in 0..kernel_len {
422                    let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
423                    let dx =
424                        *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
425                    let valid_x_start =
426                        Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
427                    let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
428                        .max(0)
429                        .min(output_width as _);
430
431                    let iptr = iptr.offset(
432                        *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
433                    );
434                    for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
435                        let y = yo as isize * y_stride + dy;
436                        let iptr = iptr.offset(yo as isize * y_stride_ptr);
437                        if y >= 0 && y < input_heigth {
438                            Self::padded_2d_invalid_x_loop(
439                                valid_x_start as usize,
440                                pad_value,
441                                &mut *writer,
442                            );
443                            Self::padded_2d_valid_x_loop(
444                                valid_x_start,
445                                valid_x_end,
446                                x_stride_ptr,
447                                iptr,
448                                &mut *writer,
449                            );
450                            Self::padded_2d_invalid_x_loop(
451                                output_width - valid_x_end as usize,
452                                pad_value,
453                                &mut *writer,
454                            );
455                        } else {
456                            Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut *writer);
457                        }
458                    }
459                }
460            }
461        }
462        Ok(())
463    }
464
465    #[inline(never)]
466    unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
467        count: usize,
468        pad_value: T,
469        writer: &mut W,
470    ) {
471        for _ in 0..count {
472            writer.write(pad_value);
473        }
474    }
475
476    #[inline(never)]
477    unsafe fn padded_2d_valid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
478        x_min: isize,
479        x_max: isize,
480        x_stride_ptr: isize,
481        iptr: *const T,
482        writer: &mut W,
483    ) {
484        // Fast path: x_stride_ptr == 1 means consecutive x values are at
485        // consecutive memory addresses, so the inner loop is a contiguous
486        // slice write — byte-identical to the per-element loop.
487        if x_stride_ptr == 1 && x_max > x_min {
488            unsafe {
489                let row = std::slice::from_raw_parts(iptr.offset(x_min), (x_max - x_min) as usize);
490                writer.write_slice(row);
491            }
492        } else {
493            for x in x_min..x_max {
494                writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
495            }
496        }
497    }
498
499    #[inline(never)]
500    fn valid_2d<T: Copy + Datum, W: PackingWriter<T>>(
501        geometry: &ConcreteGeometry,
502        input: &TensorView,
503        g: usize,
504        writer: &mut W,
505    ) -> TractResult<()> {
506        unsafe {
507            let shape = &geometry.input_shape_with_n;
508            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
509            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
510            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
511            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
512            let c_stride_ptr = *shape.c_stride() as isize;
513            let iptr = input.as_ptr_unchecked::<T>();
514            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
515            let output_y = *geometry.pool.patch.output_shape.get_unchecked(0);
516            let output_x = *geometry.pool.patch.output_shape.get_unchecked(1);
517            // Fast path: stride-1 contiguous reads along x within each y-row.
518            // Each y-row becomes a single write_slice (memcpy when the slice
519            // fits in the current panel). Byte-identical to the slow path.
520            let contiguous_x = x_stride_ptr == 1;
521            for ci in 0..geometry.ci_per_group {
522                let iptr = iptr.offset(ci as isize * c_stride_ptr);
523                for koffset in &geometry.pool.patch.standard_layout_data_field {
524                    let iptr = iptr.offset(*koffset);
525                    let mut iptr_y = iptr;
526                    for _ in 0..output_y {
527                        if contiguous_x {
528                            let row = std::slice::from_raw_parts(iptr_y, output_x);
529                            writer.write_slice(row);
530                        } else {
531                            // Hoist x multiplication out of inner loop.
532                            let mut iptr_x = iptr_y;
533                            for _ in 0..output_x {
534                                writer.write(*iptr_x);
535                                iptr_x = iptr_x.offset(x_stride_ptr);
536                            }
537                        }
538                        iptr_y = iptr_y.offset(y_stride_ptr);
539                    }
540                }
541            }
542            Ok(())
543        }
544    }
545}