1use tract_linalg::mmm::{
2 EagerPackedInput, MMMInputValue, MatMatMul, PackedExoticFact, PackedMatrixStorage,
3};
4use tract_linalg::pack::{PackedFormat, PackingWriter};
5
6use crate::internal::*;
7use ndarray::prelude::*;
8use num_integer::Integer;
9
10use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
11use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
12use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
13
/// Im2col: lays out convolution input patches as a packed matrix, ready to be
/// used as the B operand of a matrix-matrix multiplication kernel.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2Col {
    /// Pooling/patch specification (kernel geometry, strides, padding, data format).
    pub pool_spec: PoolSpec,
    /// Number of convolution groups; input channels are split evenly across groups.
    pub group: usize,
    // Stays symbolic until the full input shape is known, then concretized.
    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
}
20
// Geometry parameters known before the input shape is concrete. Resolved into
// a `ConcreteGeometry` once actual dimensions are available.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SymbolicGeometry {
    group: usize,
    pool_spec: PoolSpec,
    pool_geometry: PoolGeometry,
    // Packing format required by the selected matmul kernel for its B input.
    b_pack: PackedFormat,
    // Inner (K) dimension of the matrix product.
    k: usize,
}
29
// Fully resolved geometry for a concrete input shape.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConcreteGeometry {
    pool: ConcretePoolGeometry,
    // Number of output spatial positions (product of the output H/W dims).
    pub n: usize,
    // Inner (K) dimension of the matmul, carried over from the symbolic form.
    k: usize,
    pub b_pack: PackedFormat,
    // Input channels handled by each convolution group.
    pub ci_per_group: usize,
    // Copy strategy picked from the patch rank and padding (see `resolve`).
    patcher: Patcher,
    // Input shape normalized to carry an explicit batch dim (N=1 when absent).
    input_shape_with_n: DataShape,
    // Shape of the packed output tensor: [batch, group].
    packed_shape: TVec<usize>,
}
41
42impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
43 pub fn b_pack(&self) -> &PackedFormat {
44 match self {
45 GeometryBound::Symbolic(s) => &s.b_pack,
46 GeometryBound::Concrete(s) => &s.b_pack,
47 }
48 }
49 pub fn k(&self) -> usize {
50 match self {
51 GeometryBound::Symbolic(s) => s.k,
52 GeometryBound::Concrete(s) => s.k,
53 }
54 }
55}
56
57impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
58 type Param = [usize];
59 fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
60 let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
61 let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
62 Patcher::Valid2d
63 } else if pool.patch.rank() == 2 {
64 Patcher::Padded2d
65 } else if !pool.patch.padded && pool.patch.rank() == 1 {
66 Patcher::Valid1d
67 } else {
68 Patcher::Generic
69 };
70 let ci_per_group = pool.input_shape.c_dim() / self.group;
71 let n = pool.output_shape.hw_dims().iter().product();
72 let input_shape_with_n = match self.pool_spec.data_format {
73 DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
74 1,
75 *pool.input_shape.c(),
76 pool.input_shape.hw_dims(),
77 )?,
78 DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
79 1,
80 *pool.input_shape.c(),
81 pool.input_shape.hw_dims(),
82 )?,
83 _ => pool.input_shape.clone(),
84 };
85 let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
86 Ok(ConcreteGeometry {
87 pool,
88 n,
89 k: self.k,
90 ci_per_group,
91 b_pack: self.b_pack.clone(),
92 patcher,
93 input_shape_with_n,
94 packed_shape,
95 })
96 }
97}
98
99impl Im2Col {
100 pub fn new(
101 pool_spec: PoolSpec,
102 group: usize,
103 k: usize,
104 input_full_shape: &ShapeFact,
105 mmm: Box<dyn MatMatMul>,
106 packing: usize,
107 ) -> TractResult<Im2Col> {
108 let b_pack = mmm.packings()[packing]
109 .1
110 .downcast_ref::<PackedFormat>()
111 .context("Im2Col expects regular packed format")?
112 .clone();
113
114 let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
115 let geometry: GeometryBound<_, _> =
116 SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k }
117 .into();
118 let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
119 Ok(Im2Col { pool_spec, group, geometry })
120 }
121
122 fn packed_shape<D: DimLike>(
124 input_shape: &BaseDataShape<D, TVec<D>>,
125 group: usize,
126 ) -> TractResult<TVec<D>> {
127 let mut output_shape: TVec<D> = tvec!();
128 output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
129 output_shape.push(group.into());
130 Ok(output_shape)
131 }
132}
133
impl Op for Im2Col {
    fn name(&self) -> StaticName {
        "Im2col".into()
    }

    // Shown in model dumps; the group count is the only interesting knob here.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("groups:{}", self.group)])
    }

    op_as_typed_op!();
}
145
impl EvalOp for Im2Col {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Resolve the geometry against the actual input shape (no-op when it
        // was already concretized at construction time).
        let geometry = self.geometry.to_concrete(inputs[0].shape())?;
        unsafe {
            let mut input = inputs.remove(0).into_tensor();
            // Optional second input: the value to use for padded positions.
            // When absent, the patchers default to zero.
            let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
            if !self.pool_spec.data_format.has_n() {
                // Normalize to a batched layout so the batch loop below works.
                input.insert_axis(0)?;
            }
            let panel_bytes =
                geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of();

            let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1);
            let n_groups = self.group;
            let mut values: Vec<Box<dyn MMMInputValue>> = Vec::with_capacity(n_batches * n_groups);

            // One packed buffer per (batch, group) pair.
            for i in 0..n_batches {
                let input = input.view_at_prefix(&[i])?;
                for g in 0..n_groups {
                    // Degenerate case: an empty output spatial shape means
                    // there is nothing to pack for this pair.
                    let n =
                        if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n };
                    let mut data = Tensor::uninitialized_aligned_dt(
                        input.datum_type(),
                        &[geometry.b_pack.len(geometry.k, n)],
                        geometry.b_pack.alignment(),
                    )?;
                    if n > 0 {
                        // Dispatch on the element size to the monomorphized patcher.
                        dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
                            &geometry.patcher,
                            &geometry,
                            &input,
                            &mut data.view_mut(),
                            g,
                            pad_value
                        ))?;
                    }
                    values.push(Box::new(EagerPackedInput {
                        fact: PackedExoticFact {
                            format: Box::new(geometry.b_pack.clone()),
                            k: geometry.k,
                            mn: n.to_dim(),
                        },
                        packed: data.into_blob()?.into(),
                        panel_bytes: if n > 0 { panel_bytes } else { 0 },
                        mn: n,
                    }));
                }
            }

            // Wrap all per-(batch, group) packed buffers into a single output
            // tensor shaped `packed_shape` ([batch, group]).
            let output = PackedMatrixStorage::new_batched(&geometry.packed_shape, values)
                .into_tensor(input.datum_type());
            Ok(tvec!(output.into_tvalue()))
        }
    }
}
205
206impl TypedOp for Im2Col {
207 as_op!();
208
209 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210 let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
211 let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
212 let mn = output_shape.hw_dims().iter().product::<TDim>();
213 let pof = PackedExoticFact {
214 format: Box::new(self.geometry.b_pack().clone()),
215 k: self.geometry.k(),
216 mn,
217 };
218 Ok(tvec!(
219 inputs[0]
220 .datum_type
221 .fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
222 .with_exotic_fact(pof)
223 ))
224 }
225
226 fn declutter(
227 &self,
228 model: &TypedModel,
229 node: &TypedNode,
230 ) -> TractResult<Option<TypedModelPatch>> {
231 let input_fact = model.outlet_fact(node.inputs[0])?;
232 if node.inputs.len() == 2
233 && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
234 == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
235 {
236 Ok(Some(
237 TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
238 .with_context("b0 is zero"),
239 ))
240 } else {
241 Ok(None)
242 }
243 }
244}
245
// Specialized input-copy strategies. `Generic` is the fallback; the others are
// fast paths for the unpadded 1D/2D and padded 2D cases (see `resolve`).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
enum Patcher {
    Generic,
    Valid1d,
    Valid2d,
    Padded2d,
}
253
impl Patcher {
    /// Dispatch to the copy routine selected in `SymbolicGeometry::resolve`,
    /// packing group `g` of `input` into `pack`. `pad_value` is only consulted
    /// by the padded/generic paths; when absent, zero is used.
    fn patch<'p, T: Copy + Datum + num_traits::Zero>(
        &self,
        geo: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: Option<&Tensor>,
    ) -> TractResult<()> {
        match self {
            Patcher::Valid1d => Self::valid_1d::<T>(geo, input, pack, g),
            Patcher::Valid2d => Self::valid_2d::<T>(geo, input, pack, g),
            Patcher::Padded2d => Self::padded_2d::<T>(
                geo,
                input,
                pack,
                g,
                // The zero-scalar temporary lives until the end of this call
                // statement, so the borrow is sound.
                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
            ),
            _ => Self::generic::<T>(
                geo,
                input,
                pack,
                g,
                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
            ),
        }
    }

    /// Fallback path: materialize the whole im2col matrix (k x n) in a
    /// temporary tensor, then hand it to the packer in one go.
    #[inline(never)]
    fn generic<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: &Tensor,
    ) -> TractResult<()> {
        unsafe {
            let pad_value = *pad_value.to_scalar_unchecked();
            let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
            let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
            let ptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
            // One column of the mega matrix per output spatial position.
            for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
                .into_iter()
                .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
            {
                let mut col = col.iter_mut();
                for ci in 0..geometry.ci_per_group {
                    let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
                    // `patch.at` yields one Option<offset> per kernel position;
                    // a None (position outside the input) gets the pad value.
                    for v in geometry.pool.patch.at(spatial.slice()) {
                        *col.next().expect("geometry error in conv") =
                            v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
                    }
                }
            }
            geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1);
            Ok(())
        }
    }

    /// Unpadded 1D fast path: stream values straight into the k-outer pack
    /// writer — channel, then kernel offset, then output position.
    #[inline(never)]
    fn valid_1d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
    ) -> TractResult<()> {
        unsafe {
            // Element stride between two consecutive output positions.
            let x_stride = *geometry.input_shape_with_n.h_stride() as isize
                * geometry.pool.patch.spec.strides[0] as isize;
            let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
            let pack = pack.as_slice_mut_unchecked::<T>();
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride);
                for koffset in &geometry.pool.patch.standard_layout_data_field {
                    let iptr = iptr.offset(*koffset);
                    for x in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        writer.write(*iptr.offset(x as isize * x_stride));
                    }
                }
            }
            Ok(())
        }
    }

    /// Padded 2D fast path: for each kernel position, split every output row
    /// into [left padding | valid span | right padding] and write each span
    /// with a dedicated inner loop.
    #[inline(never)]
    fn padded_2d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: &Tensor,
    ) -> TractResult<()> {
        unsafe {
            let pad_value = *pad_value.to_scalar_unchecked();
            let pack = pack.as_slice_mut_unchecked::<T>();
            // Strides in output positions...
            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
            let shape = &geometry.input_shape_with_n;
            // ...and the corresponding strides in input elements.
            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
            let c_stride_ptr = *shape.c_stride() as isize;
            let input_heigth = shape.hw_dims()[0] as isize;
            let input_width = shape.hw_dims()[1] as isize;
            let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
            let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride_ptr);
                for kitem in 0..kernel_len {
                    // (dy, dx): spatial displacement of this kernel position,
                    // stored interleaved in `data_field`.
                    let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
                    let dx =
                        *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
                    // Output-x range whose sampled input-x stays in bounds;
                    // both ends clamped to [0, output_width].
                    let valid_x_start =
                        Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
                    let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
                        .max(0)
                        .min(output_width as _);

                    let iptr = iptr.offset(
                        *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
                    );
                    for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        let y = yo as isize * y_stride + dy;
                        let iptr = iptr.offset(yo as isize * y_stride_ptr);
                        if y >= 0 && y < input_heigth {
                            // Row intersects the input: pad left, copy the
                            // valid span, pad right.
                            Self::padded_2d_invalid_x_loop(
                                valid_x_start as usize,
                                pad_value,
                                &mut writer,
                            );
                            Self::padded_2d_valid_x_loop(
                                valid_x_start,
                                valid_x_end,
                                x_stride_ptr,
                                iptr,
                                &mut writer,
                            );
                            Self::padded_2d_invalid_x_loop(
                                output_width - valid_x_end as usize,
                                pad_value,
                                &mut writer,
                            );
                        } else {
                            // Row entirely outside the input: all padding.
                            Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer);
                        }
                    }
                }
            }
        }
        Ok(())
    }

    // Emit `count` copies of the padding value.
    #[inline(never)]
    unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum>(
        count: usize,
        pad_value: T,
        writer: &mut tract_linalg::pack::KOutWriter<T>,
    ) {
        for _ in 0..count {
            writer.write(pad_value);
        }
    }

    // Copy the in-bounds span [x_min, x_max) sampled at `x_stride_ptr`
    // elements per step. Caller guarantees the offsets are in bounds.
    #[inline(never)]
    unsafe fn padded_2d_valid_x_loop<T: Copy + Datum>(
        x_min: isize,
        x_max: isize,
        x_stride_ptr: isize,
        iptr: *const T,
        writer: &mut tract_linalg::pack::KOutWriter<T>,
    ) {
        for x in x_min..x_max {
            writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
        }
    }

    /// Unpadded 2D fast path: every sampled position is in bounds, so copy
    /// channel / kernel-offset / row / column with no boundary tests.
    #[inline(never)]
    fn valid_2d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
    ) -> TractResult<()> {
        unsafe {
            let pack = pack.as_slice_mut_unchecked::<T>();
            let shape = &geometry.input_shape_with_n;
            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
            let c_stride_ptr = *shape.c_stride() as isize;
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride_ptr);
                for koffset in &geometry.pool.patch.standard_layout_data_field {
                    let iptr = iptr.offset(*koffset);
                    for y in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        let iptr = iptr.offset(y as isize * y_stride_ptr);
                        for x in 0..*geometry.pool.patch.output_shape.get_unchecked(1) {
                            writer.write(*iptr.offset(x as isize * x_stride_ptr));
                        }
                    }
                }
            }
            Ok(())
        }
    }
}