tract_core/ops/cnn/conv/mod.rs

1#[allow(clippy::module_inception)]
2mod conv;
3mod depth_wise;
4mod im2col;
5mod lazy_im2col;
6mod q_sum_b;
7
8use crate::internal::*;
9
10pub use self::im2col::Im2Col;
11pub(crate) use self::q_sum_b::QSumB;
12pub use self::conv::Conv;
13
/// Axis order of a convolution kernel tensor.
///
/// Letters name the axes in order: `O` is the output-channel axis, `I` the input-channel
/// axis and `HW` stands for all the spatial axes, which are contiguous.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// `[O, I, *spatial]` — the default layout.
    #[default]
    OIHW,
    /// `[*spatial, I, O]`.
    HWIO,
    /// `[O, *spatial, I]`.
    OHWI,
}
21
22impl KernelFormat {
23    pub fn h_axis(&self) -> usize {
24        match self {
25            KernelFormat::OIHW => 2,
26            KernelFormat::HWIO => 0,
27            KernelFormat::OHWI => 1,
28        }
29    }
30
31    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
32        &full_shape[self.h_axis()..][..full_shape.len() - 2]
33    }
34
35    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
36        self.spatial_shape(full_shape)
37    }
38
39    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
40        match self {
41            KernelFormat::OIHW => &full_shape[1],
42            KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
43            KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
44        }
45    }
46
47    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
48        match self {
49            KernelFormat::OIHW | KernelFormat::OHWI => 0,
50            KernelFormat::HWIO => full_shape.len() - 1,
51        }
52    }
53
54    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
55        &full_shape[self.o_axis(full_shape)]
56    }
57
58    pub fn input_channels<'s, D: DimLike>(
59        &self,
60        full_kernel_shape: &'s [D],
61        group: usize,
62    ) -> Cow<'s, D> {
63        match self {
64            KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
65            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
66        }
67    }
68
69    pub fn output_channels<'s, D: DimLike>(
70        &self,
71        full_kernel_shape: &'s [D],
72        group: usize,
73    ) -> Cow<'s, D> {
74        match self {
75            KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
76            KernelFormat::HWIO | KernelFormat::OHWI => {
77                Cow::Owned(self.o(full_kernel_shape).clone() * group)
78            }
79        }
80    }
81
82    pub fn kernel_as_group_o_i_h_w_ops(
83        &self,
84        full_shape: &[impl DimLike],
85        group: usize,
86    ) -> TVec<AxisOp> {
87        let geo_rank = full_shape.len() - 2;
88        match self {
89            // g is on i
90            KernelFormat::HWIO => {
91                tvec!(
92                    AxisOp::Reshape(
93                        geo_rank,
94                        tvec!(self.i(full_shape).to_dim()),
95                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
96                    ), // h w g i o
97                    AxisOp::Move(geo_rank, 0),     // g h w i o
98                    AxisOp::Move(geo_rank + 2, 1), // g o h w i
99                    AxisOp::Move(geo_rank + 2, 2)
100                ) // g o i h w
101            }
102            // g is on o
103            KernelFormat::OIHW => {
104                tvec!(AxisOp::Reshape(
105                    0,
106                    tvec!(self.o(full_shape).to_dim()),
107                    tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
108                ))
109            }
110            // g is on i
111            KernelFormat::OHWI => {
112                tvec!(
113                    AxisOp::Reshape(
114                        geo_rank + 1,
115                        tvec!(self.i(full_shape).to_dim()),
116                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
117                    ), // o h w g i
118                    AxisOp::Move(geo_rank + 1, 0), // g o h w i
119                    AxisOp::Move(geo_rank + 2, 2)
120                )
121            }
122        }
123    }
124
125    pub fn kernel_as_group_o_i_hw_ops(
126        &self,
127        full_shape: &[impl DimLike],
128        group: usize,
129    ) -> TVec<AxisOp> {
130        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
131        if self.hw(full_shape).len() > 1 {
132            ops.push(AxisOp::Reshape(
133                3,
134                self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
135                tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
136            ));
137        }
138        ops
139    }
140
141    pub fn kernel_as_group_o_ihw_ops(
142        &self,
143        full_shape: &[impl DimLike],
144        group: usize,
145    ) -> TVec<AxisOp> {
146        let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
147        let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
148        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
149        ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
150        ops
151    }
152
153    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
154        let mut kernel = kernel.clone();
155        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
156        for op in &ops {
157            op.change_tensor(&mut kernel, false)?;
158        }
159        Ok(kernel)
160    }
161
162    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
163        let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
164        Ok(group_o_i_hw.collapse_axis_with_next(2))
165    }
166}