1use tract_linalg::mmm::{EagerPackedInput, MMMInputValue, MatMatMul, PackedOpaqueFact};
2use tract_linalg::pack::{PackedFormat, PackingWriter};
3
4use crate::internal::*;
5use ndarray::prelude::*;
6use num_integer::Integer;
7
8use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry};
9use crate::ops::cnn::{GeometryBound, PoolSpec, ResolveTo};
10use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
11
/// Im2col operator: expands convolution input patches into pre-packed B
/// operands for a matrix-matrix multiplication, one opaque payload per
/// (batch item, group).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2Col {
    /// Pooling/convolution window specification (kernel, strides, padding...).
    pub pool_spec: PoolSpec,
    /// Number of convolution groups.
    pub group: usize,
    // Geometry stays symbolic until the input shape is fully known, then is
    // resolved to a ConcreteGeometry (see `optimize_if` in `new`).
    geometry: GeometryBound<SymbolicGeometry, ConcreteGeometry>,
}
18
/// Operator geometry while some input dimensions are still symbolic.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SymbolicGeometry {
    group: usize,
    pool_spec: PoolSpec,
    pool_geometry: PoolGeometry,
    /// Packing format the mat-mat-mul kernel expects for its B operand.
    b_pack: PackedFormat,
    /// Reduction dimension of the packed matrix (rows of the k×n matrix;
    /// filled as channels-per-group × kernel footprint, see `Patcher::generic`).
    k: usize,
}
27
/// Fully resolved geometry: every dimension and stride is a known integer.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConcreteGeometry {
    pool: ConcretePoolGeometry,
    /// Columns of the packed matrix: number of output spatial positions.
    pub n: usize,
    /// Rows of the packed matrix (reduction dimension).
    k: usize,
    pub b_pack: PackedFormat,
    /// Input channels per convolution group.
    pub ci_per_group: usize,
    /// Copy strategy selected from patch rank and padding (see `resolve`).
    patcher: Patcher,
    /// Input shape normalized so it always carries an N (batch) axis.
    input_shape_with_n: DataShape,
    /// Shape of the opaque output tensor: [batch, group].
    packed_shape: TVec<usize>, }
39
40impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
41 pub fn b_pack(&self) -> &PackedFormat {
42 match self {
43 GeometryBound::Symbolic(s) => &s.b_pack,
44 GeometryBound::Concrete(s) => &s.b_pack,
45 }
46 }
47 pub fn k(&self) -> usize {
48 match self {
49 GeometryBound::Symbolic(s) => s.k,
50 GeometryBound::Concrete(s) => s.k,
51 }
52 }
53}
54
55impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
56 type Param = [usize];
57 fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcreteGeometry> {
58 let pool = self.pool_geometry.to_concrete(input_full_shape)?.into_owned();
59 let patcher = if !pool.patch.padded && pool.patch.rank() == 2 {
60 Patcher::Valid2d
61 } else if pool.patch.rank() == 2 {
62 Patcher::Padded2d
63 } else if !pool.patch.padded && pool.patch.rank() == 1 {
64 Patcher::Valid1d
65 } else {
66 Patcher::Generic
67 };
68 let ci_per_group = pool.input_shape.c_dim() / self.group;
69 let n = pool.output_shape.hw_dims().iter().product();
70 let input_shape_with_n = match self.pool_spec.data_format {
71 DataFormat::HWC => DataFormat::NHWC.from_n_c_hw(
72 1,
73 *pool.input_shape.c(),
74 pool.input_shape.hw_dims(),
75 )?,
76 DataFormat::CHW => DataFormat::NCHW.from_n_c_hw(
77 1,
78 *pool.input_shape.c(),
79 pool.input_shape.hw_dims(),
80 )?,
81 _ => pool.input_shape.clone(),
82 };
83 let packed_shape = Im2Col::packed_shape(&pool.input_shape, self.group)?;
84 Ok(ConcreteGeometry {
85 pool,
86 n,
87 k: self.k,
88 ci_per_group,
89 b_pack: self.b_pack.clone(),
90 patcher,
91 input_shape_with_n,
92 packed_shape,
93 })
94 }
95}
96
97impl Im2Col {
98 pub fn new(
99 pool_spec: PoolSpec,
100 group: usize,
101 k: usize,
102 input_full_shape: &ShapeFact,
103 mmm: Box<dyn MatMatMul>,
104 packing: usize,
105 ) -> TractResult<Im2Col> {
106 let b_pack = mmm.packings()[packing]
107 .1
108 .downcast_ref::<PackedFormat>()
109 .context("Im2Col expects regular packed format")?
110 .clone();
111
112 let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
113 let geometry: GeometryBound<_, _> =
114 SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k }
115 .into();
116 let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
117 Ok(Im2Col { pool_spec, group, geometry })
118 }
119
120 fn packed_shape<D: DimLike>(
122 input_shape: &BaseDataShape<D, TVec<D>>,
123 group: usize,
124 ) -> TractResult<TVec<D>> {
125 let mut output_shape: TVec<D> = tvec!();
126 output_shape.push(input_shape.n().cloned().unwrap_or_else(|| 1.into()));
127 output_shape.push(group.into());
128 Ok(output_shape)
129 }
130}
131
132impl Op for Im2Col {
133 fn name(&self) -> StaticName {
134 "Im2col".into()
135 }
136
137 fn info(&self) -> TractResult<Vec<String>> {
138 Ok(vec![format!("groups:{}", self.group)])
139 }
140
141 impl_op_same_as!();
142 op_as_typed_op!();
143}
144
impl EvalOp for Im2Col {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Packs the input image into one opaque, pre-packed B operand per
    /// (batch item, group), ready for consumption by a mat-mat-mul kernel.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Resolve symbolic geometry against the actual input shape.
        let geometry = self.geometry.to_concrete(inputs[0].shape())?;
        unsafe {
            let mut input = inputs.remove(0).into_tensor();
            // Optional second input carries an explicit padding value.
            let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
            let mut output = Tensor::uninitialized::<Opaque>(&geometry.packed_shape)?;
            // Inject a unit batch axis so the loop below can index by batch.
            if !self.pool_spec.data_format.has_n() {
                input.insert_axis(0)?;
            }
            let mut output_dense = output.try_as_dense_mut()?;
            let mut output_view = output_dense.to_array_view_mut::<Opaque>()?;
            let panel_bytes =
                geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of();

            // NOTE(review): when the output spatial shape is empty the [n, group]
            // opaque cells are returned uninitialized — presumably never read
            // downstream in that case; confirm.
            if !geometry.pool.output_shape.shape.contains(&0) {
                for i in 0..*geometry.input_shape_with_n.n().unwrap_or(&1) {
                    let input = input.view_at_prefix(&[i])?;
                    for g in 0..self.group {
                        // Destination buffer, aligned as the packing format requires.
                        let mut data = Tensor::uninitialized_aligned_dt(
                            input.datum_type(),
                            &[geometry.b_pack.len(geometry.k, geometry.n)],
                            geometry.b_pack.alignment(),
                        )?;
                        // Copy (and pad) the patches of group `g` into `data`,
                        // monomorphized on the input's datum type.
                        dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
                            &geometry.patcher,
                            &geometry,
                            &input,
                            &mut data.view_mut(),
                            g,
                            pad_value
                        ))?;
                        // Wrap the packed bytes as an eager MMM input value.
                        let input: Box<dyn MMMInputValue> = Box::new(EagerPackedInput {
                            fact: PackedOpaqueFact {
                                format: Box::new(geometry.b_pack.clone()),
                                k: geometry.k,
                                mn: geometry.n.to_dim(),
                            },
                            packed: data.into_blob()?.into(),
                            panel_bytes,
                            mn: geometry.n,
                        });
                        output_view[[i, g]] = input.into();
                    }
                }
            }
            Ok(tvec!(output.into_tvalue()))
        }
    }
}
201
impl TypedOp for Im2Col {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input_shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
        // mn: number of output spatial positions (columns of the packed B).
        let mn = output_shape.hw_dims().iter().product::<TDim>();
        let pof = PackedOpaqueFact {
            format: Box::new(self.geometry.b_pack().clone()),
            k: self.geometry.k(),
            mn,
        };
        // One opaque packed payload per (batch, group).
        Ok(tvec!(
            Opaque::fact(&[input_shape.n().cloned().unwrap_or(1.into()), self.group.into()])
                .with_opaque_fact(pof)
        ))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        // Drop the explicit padding-value input when it is a constant uniform
        // zero: the patchers fall back to a zero scalar anyway (see
        // `Patcher::patch`), so the extra input carries no information.
        if node.inputs.len() == 2
            && model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
                == Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
        {
            Ok(Some(
                TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
                    .with_context("b0 is zero"),
            ))
        } else {
            Ok(None)
        }
    }
}
239
/// Patch-copy strategies, from fully generic to specialized fast paths.
/// Selected once per geometry in `SymbolicGeometry::resolve`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
enum Patcher {
    /// Any rank, any padding: goes through an intermediate k×n matrix.
    Generic,
    /// Rank 1, no padding: direct strided copy into the packer.
    Valid1d,
    /// Rank 2, no padding: direct strided copy into the packer.
    Valid2d,
    /// Rank 2 with padding: each output row split into pad/valid/pad spans.
    Padded2d,
}
247
impl Patcher {
    /// Entry point: copies (and pads) the group-`g` channel slice of `input`
    /// into `pack`, using the strategy selected at geometry-resolution time.
    ///
    /// `T` must match the input's datum type — guaranteed by the
    /// `dispatch_copy_by_size!` call site in `Im2Col::eval`.
    fn patch<'p, T: Copy + Datum + num_traits::Zero>(
        &self,
        geo: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: Option<&Tensor>,
    ) -> TractResult<()> {
        match self {
            Patcher::Valid1d => Self::valid_1d::<T>(geo, input, pack, g),
            Patcher::Valid2d => Self::valid_2d::<T>(geo, input, pack, g),
            // NOTE(review): the fallback zero scalar is built eagerly even when
            // `pad_value` is Some (argument of `unwrap_or`); cheap, but wasted
            // work on the explicit-padding path.
            Patcher::Padded2d => Self::padded_2d::<T>(
                geo,
                input,
                pack,
                g,
                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
            ),
            _ => Self::generic::<T>(
                geo,
                input,
                pack,
                g,
                pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
            ),
        }
    }

    /// Fallback for any rank/padding: materializes the full k×n im2col matrix,
    /// then hands it to the packer in one call. Slower (extra copy) but simple.
    #[inline(never)]
    fn generic<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: &Tensor,
    ) -> TractResult<()> {
        unsafe {
            let pad_value = *pad_value.to_scalar_unchecked();
            let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
            let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
            let ptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
            // One column per output spatial position.
            for (spatial, mut col) in ndarray::indices(&*geometry.pool.patch.output_shape)
                .into_iter()
                .zip(mega_matrix_view.axis_iter_mut(Axis(1)))
            {
                let mut col = col.iter_mut();
                for ci in 0..geometry.ci_per_group {
                    let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * ci);
                    // `patch.at` yields Some(offset) for in-bounds kernel taps,
                    // None where padding applies.
                    for v in geometry.pool.patch.at(spatial.slice()) {
                        *col.next().expect("geometry error in conv") =
                            v.map(|o| *ptr.offset(o)).unwrap_or(pad_value);
                    }
                }
            }
            geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1);
            Ok(())
        }
    }

    /// Fast path, rank 1 without padding: streams values straight into the
    /// packer's k-outer writer, no intermediate matrix.
    #[inline(never)]
    fn valid_1d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
    ) -> TractResult<()> {
        unsafe {
            // Element stride between consecutive output positions.
            let x_stride = *geometry.input_shape_with_n.h_stride() as isize
                * geometry.pool.patch.spec.strides[0] as isize;
            let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
            let pack = pack.as_slice_mut_unchecked::<T>();
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
            // k-outer order: channel, then kernel tap, then output position.
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride);
                for koffset in &geometry.pool.patch.standard_layout_data_field {
                    let iptr = iptr.offset(*koffset);
                    for x in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        writer.write(*iptr.offset(x as isize * x_stride));
                    }
                }
            }
            Ok(())
        }
    }

    /// Fast path, rank 2 with padding: for each kernel tap, precomputes the
    /// interval of output columns whose taps land in-bounds, so the inner
    /// loops are branch-free (pad-run / valid-run / pad-run per row).
    #[inline(never)]
    fn padded_2d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
        pad_value: &Tensor,
    ) -> TractResult<()> {
        unsafe {
            let pad_value = *pad_value.to_scalar_unchecked();
            let pack = pack.as_slice_mut_unchecked::<T>();
            // Strides in output-position units and in input elements.
            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
            let shape = &geometry.input_shape_with_n;
            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
            let c_stride_ptr = *shape.c_stride() as isize;
            let input_heigth = shape.hw_dims()[0] as isize;
            let input_width = shape.hw_dims()[1] as isize;
            let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
            let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride_ptr);
                for kitem in 0..kernel_len {
                    // (dy, dx): displacement of this kernel tap, interleaved
                    // pairs in `data_field`.
                    let dy = *geometry.pool.patch.data_field.as_ptr().offset(kitem as isize * 2);
                    let dx =
                        *geometry.pool.patch.data_field.as_ptr().offset(1 + kitem as isize * 2);
                    // Output columns xo with 0 <= xo*x_stride + dx < input_width,
                    // clamped to [0, output_width].
                    let valid_x_start =
                        Integer::div_ceil(&-dx, &x_stride).max(0).min(output_width as _);
                    let valid_x_end = Integer::div_ceil(&(input_width - dx), &x_stride)
                        .max(0)
                        .min(output_width as _);

                    let iptr = iptr.offset(
                        *geometry.pool.patch.standard_layout_data_field.get_unchecked(kitem),
                    );
                    for yo in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        let y = yo as isize * y_stride + dy;
                        let iptr = iptr.offset(yo as isize * y_stride_ptr);
                        if y >= 0 && y < input_heigth {
                            // Row in bounds vertically: pad left, copy valid
                            // span, pad right.
                            Self::padded_2d_invalid_x_loop(
                                valid_x_start as usize,
                                pad_value,
                                &mut writer,
                            );
                            Self::padded_2d_valid_x_loop(
                                valid_x_start,
                                valid_x_end,
                                x_stride_ptr,
                                iptr,
                                &mut writer,
                            );
                            Self::padded_2d_invalid_x_loop(
                                output_width - valid_x_end as usize,
                                pad_value,
                                &mut writer,
                            );
                        } else {
                            // Entire row falls outside vertically: all padding.
                            Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer);
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Emits `count` copies of the padding value into the packer.
    #[inline(never)]
    unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum>(
        count: usize,
        pad_value: T,
        writer: &mut tract_linalg::pack::KOutWriter<T>,
    ) {
        for _ in 0..count {
            writer.write(pad_value);
        }
    }

    /// Copies the in-bounds span [x_min, x_max) from `iptr` (already offset to
    /// this row and kernel tap) into the packer.
    #[inline(never)]
    unsafe fn padded_2d_valid_x_loop<T: Copy + Datum>(
        x_min: isize,
        x_max: isize,
        x_stride_ptr: isize,
        iptr: *const T,
        writer: &mut tract_linalg::pack::KOutWriter<T>,
    ) {
        for x in x_min..x_max {
            writer.write(unsafe { *iptr.offset(x * x_stride_ptr) });
        }
    }

    /// Fast path, rank 2 without padding: every tap is in-bounds, plain
    /// strided copy in k-outer order (channel, tap, y, x).
    #[inline(never)]
    fn valid_2d<'p, T: Copy + Datum>(
        geometry: &'p ConcreteGeometry,
        input: &TensorView,
        pack: &'p mut TensorView,
        g: usize,
    ) -> TractResult<()> {
        unsafe {
            let pack = pack.as_slice_mut_unchecked::<T>();
            let shape = &geometry.input_shape_with_n;
            let y_stride = geometry.pool.patch.spec.strides[0] as isize;
            let x_stride = geometry.pool.patch.spec.strides[1] as isize;
            let y_stride_ptr = y_stride * *shape.h_stride() as isize;
            let x_stride_ptr = x_stride * *shape.w_stride() as isize;
            let c_stride_ptr = *shape.c_stride() as isize;
            let mut writer =
                geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
            let iptr = input.as_ptr_unchecked::<T>();
            // Skip to the first channel of group `g`.
            let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
            for ci in 0..geometry.ci_per_group {
                let iptr = iptr.offset(ci as isize * c_stride_ptr);
                for koffset in &geometry.pool.patch.standard_layout_data_field {
                    let iptr = iptr.offset(*koffset);
                    for y in 0..*geometry.pool.patch.output_shape.get_unchecked(0) {
                        let iptr = iptr.offset(y as isize * y_stride_ptr);
                        for x in 0..*geometry.pool.patch.output_shape.get_unchecked(1) {
                            writer.write(*iptr.offset(x as isize * x_stride_ptr));
                        }
                    }
                }
            }
            Ok(())
        }
    }
}