zenu_matrix/nn/col2im.rs

1use crate::{
2    device::Device,
3    dim::DimDyn,
4    matrix::{Matrix, Owned, Ref},
5    num::Num,
6    slice_dynamic,
7};
8
9#[expect(clippy::needless_pass_by_value)]
10pub(super) fn col2im<T: Num, D: Device>(
11    col: Matrix<Ref<&T>, DimDyn, D>,
12    img_shape: [usize; 4],
13    kernel_size: (usize, usize),
14    stride: (usize, usize),
15    pad: (usize, usize),
16) -> Matrix<Owned<T>, DimDyn, D> {
17    let (batch_size, c, h, w) = (img_shape[0], img_shape[1], img_shape[2], img_shape[3]);
18    let (kh, kw) = kernel_size;
19    let (sh, sw) = stride;
20    let (ph, pw) = pad;
21    let (oh, ow) = ((h + 2 * ph - kh) / sh + 1, (w + 2 * pw - kw) / sw + 1);
22
23    let mut img =
24        Matrix::<_, DimDyn, _>::zeros([batch_size, c, h + 2 * ph + sh - 1, w + 2 * pw + sw - 1]);
25
26    for j in 0..kh {
27        let j_lim = j + sh * oh;
28        for i in 0..kw {
29            let i_lim = i + sw * ow;
30            let col_ref = col.to_ref();
31            let col_ref = col_ref.slice_dyn(slice_dynamic!(.., .., j, i, .., ..));
32
33            let mut img_slice = img.to_ref_mut().slice_mut_dyn(slice_dynamic!(
34                ..,
35                ..,
36                j..j_lim;sh,
37                i..i_lim;sw
38            ));
39            img_slice += col_ref;
40        }
41    }
42
43    let img = img.slice_dyn(slice_dynamic!(.., .., ph..ph + h, pw..pw + w));
44    img.new_matrix()
45}
46
#[cfg(test)]
mod col2im {
    use crate::{
        device::cpu::Cpu,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use super::col2im;

    #[expect(clippy::cast_precision_loss)]
    #[test]
    fn col2im_small() {
        let col = (1..=1350).map(|x| x as f32).collect::<Vec<f32>>();
        let col = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(col, [2, 3, 3, 3, 5, 5]);
        let img_shape = [2, 3, 5, 5];
        let kernel_shape = (3, 3);
        let stride = (1, 1);
        let pad = (1, 1);
        let img = col2im(col.to_ref(), img_shape, kernel_shape, stride, pad);
        let ans = vec![
            216, 402, 408, 414, 328, 564, 963, 972, 981, 732, 594, 1008, 1017, 1026, 762, 624,
            1053, 1062, 1071, 792, 576, 942, 948, 954, 688, 1116, 1752, 1758, 1764, 1228, 1914,
            2988, 2997, 3006, 2082, 1944, 3033, 3042, 3051, 2112, 1974, 3078, 3087, 3096, 2142,
            1476, 2292, 2298, 2304, 1588, 2016, 3102, 3108, 3114, 2128, 3264, 5013, 5022, 5031,
            3432, 3294, 5058, 5067, 5076, 3462, 3324, 5103, 5112, 5121, 3492, 2376, 3642, 3648,
            3654, 2488, 2916, 4452, 4458, 4464, 3028, 4614, 7038, 7047, 7056, 4782, 4644, 7083,
            7092, 7101, 4812, 4674, 7128, 7137, 7146, 4842, 3276, 4992, 4998, 5004, 3388, 3816,
            5802, 5808, 5814, 3928, 5964, 9063, 9072, 9081, 6132, 5994, 9108, 9117, 9126, 6162,
            6024, 9153, 9162, 9171, 6192, 4176, 6342, 6348, 6354, 4288, 4716, 7152, 7158, 7164,
            4828, 7314, 11088, 11097, 11106, 7482, 7344, 11133, 11142, 11151, 7512, 7374, 11178,
            11187, 11196, 7542, 5076, 7692, 7698, 7704, 5188,
        ]
        .iter()
        .map(|&x| x as f32)
        .collect::<Vec<f32>>();
        let ans = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(ans, [2, 3, 5, 5]);
        assert!((img - ans).asum() < 1e-6);
    }

    /// Stride 2 with a 2x2 kernel on a 4x4 image: the windows do not overlap, so each pixel
    /// receives exactly one `col` element. Distinct values pin the `(kh, kw, oh, ow)` layout
    /// and the strided-slice step, which the stride-1 test cannot distinguish.
    #[test]
    fn col2im_stride2_no_overlap() {
        let col = (1u8..=16).map(f32::from).collect::<Vec<f32>>();
        let col = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(col, [1, 1, 2, 2, 2, 2]);
        let img = col2im(col.to_ref(), [1, 1, 4, 4], (2, 2), (2, 2), (0, 0));
        let ans = vec![
            1.0, 5.0, 2.0, 6.0, //
            9.0, 13.0, 10.0, 14.0, //
            3.0, 7.0, 4.0, 8.0, //
            11.0, 15.0, 12.0, 16.0,
        ];
        let ans = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(ans, [1, 1, 4, 4]);
        assert!((img - ans).asum() < 1e-6);
    }

    /// Rectangular 1x2 kernel on a 2x3 image: catches a height/width mix-up. With an
    /// all-ones `col`, each output pixel equals the number of windows covering it.
    #[test]
    fn col2im_rectangular_kernel_counts_overlaps() {
        let col = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![1.0; 8], [1, 1, 1, 2, 2, 2]);
        let img = col2im(col.to_ref(), [1, 1, 2, 3], (1, 2), (1, 1), (0, 0));
        let ans = vec![
            1.0, 2.0, 1.0, //
            1.0, 2.0, 1.0,
        ];
        let ans = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(ans, [1, 1, 2, 3]);
        assert!((img - ans).asum() < 1e-6);
    }
}