1use crate::{
2 device::Device,
3 dim::DimDyn,
4 matrix::{Matrix, Owned, Ref},
5 num::Num,
6 slice_dynamic,
7};
8
9#[expect(clippy::needless_pass_by_value)]
10pub(super) fn col2im<T: Num, D: Device>(
11 col: Matrix<Ref<&T>, DimDyn, D>,
12 img_shape: [usize; 4],
13 kernel_size: (usize, usize),
14 stride: (usize, usize),
15 pad: (usize, usize),
16) -> Matrix<Owned<T>, DimDyn, D> {
17 let (batch_size, c, h, w) = (img_shape[0], img_shape[1], img_shape[2], img_shape[3]);
18 let (kh, kw) = kernel_size;
19 let (sh, sw) = stride;
20 let (ph, pw) = pad;
21 let (oh, ow) = ((h + 2 * ph - kh) / sh + 1, (w + 2 * pw - kw) / sw + 1);
22
23 let mut img =
24 Matrix::<_, DimDyn, _>::zeros([batch_size, c, h + 2 * ph + sh - 1, w + 2 * pw + sw - 1]);
25
26 for j in 0..kh {
27 let j_lim = j + sh * oh;
28 for i in 0..kw {
29 let i_lim = i + sw * ow;
30 let col_ref = col.to_ref();
31 let col_ref = col_ref.slice_dyn(slice_dynamic!(.., .., j, i, .., ..));
32
33 let mut img_slice = img.to_ref_mut().slice_mut_dyn(slice_dynamic!(
34 ..,
35 ..,
36 j..j_lim;sh,
37 i..i_lim;sw
38 ));
39 img_slice += col_ref;
40 }
41 }
42
43 let img = img.slice_dyn(slice_dynamic!(.., .., ph..ph + h, pw..pw + w));
44 img.new_matrix()
45}
46
#[cfg(test)]
mod col2im {
    use super::col2im;
    use crate::{
        device::cpu::Cpu,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Folds a `[2, 3, 3, 3, 5, 5]` column tensor filled with `1..=1350` back into a
    /// `[2, 3, 5, 5]` image (3x3 kernel, stride 1, padding 1) and compares the result with
    /// precomputed reference values.
    #[expect(clippy::cast_precision_loss)]
    #[test]
    fn col2im_small() {
        type CpuMatrix = Matrix<Owned<f32>, DimDyn, Cpu>;

        // Convolution geometry under test.
        let img_shape = [2, 3, 5, 5];
        let kernel_shape = (3, 3);
        let stride = (1, 1);
        let pad = (1, 1);

        // Input columns: consecutive values make every output element distinct.
        let col_values: Vec<f32> = (1..=1350).map(|v| v as f32).collect();
        let col = CpuMatrix::from_vec(col_values, [2, 3, 3, 3, 5, 5]);

        // Reference image, one value per output element.
        let expected: Vec<f32> = [
            216, 402, 408, 414, 328, 564, 963, 972, 981, 732, 594, 1008, 1017, 1026, 762, 624,
            1053, 1062, 1071, 792, 576, 942, 948, 954, 688, 1116, 1752, 1758, 1764, 1228, 1914,
            2988, 2997, 3006, 2082, 1944, 3033, 3042, 3051, 2112, 1974, 3078, 3087, 3096, 2142,
            1476, 2292, 2298, 2304, 1588, 2016, 3102, 3108, 3114, 2128, 3264, 5013, 5022, 5031,
            3432, 3294, 5058, 5067, 5076, 3462, 3324, 5103, 5112, 5121, 3492, 2376, 3642, 3648,
            3654, 2488, 2916, 4452, 4458, 4464, 3028, 4614, 7038, 7047, 7056, 4782, 4644, 7083,
            7092, 7101, 4812, 4674, 7128, 7137, 7146, 4842, 3276, 4992, 4998, 5004, 3388, 3816,
            5802, 5808, 5814, 3928, 5964, 9063, 9072, 9081, 6132, 5994, 9108, 9117, 9126, 6162,
            6024, 9153, 9162, 9171, 6192, 4176, 6342, 6348, 6354, 4288, 4716, 7152, 7158, 7164,
            4828, 7314, 11088, 11097, 11106, 7482, 7344, 11133, 11142, 11151, 7512, 7374, 11178,
            11187, 11196, 7542, 5076, 7692, 7698, 7704, 5188,
        ]
        .iter()
        .map(|&v| v as f32)
        .collect();
        let expected = CpuMatrix::from_vec(expected, [2, 3, 5, 5]);

        let actual = col2im(col.to_ref(), img_shape, kernel_shape, stride, pad);

        // The summed absolute difference must be negligible.
        assert!((actual - expected).asum() < 1e-6);
    }
}