1#[allow(clippy::module_inception)]
2mod conv;
3mod depth_wise;
4mod im2col;
5mod lazy_im2col;
6mod q_sum_b;
7
8use crate::internal::*;
9
10pub use self::im2col::Im2Col;
11pub(crate) use self::q_sum_b::QSumB;
12pub use self::conv::Conv;
13
/// Axis ordering of a convolution kernel tensor.
///
/// Each letter names one kind of axis: `O` for output channels, `I` for input
/// channels, and `HW` for the run of spatial axes (every axis that is neither
/// `O` nor `I`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes. This is the default.
    #[default]
    OIHW,
    /// Spatial axes first, then input channels, then output channels.
    HWIO,
    /// Output channels, the spatial axes, then input channels.
    OHWI,
}
21
22impl KernelFormat {
23 pub fn h_axis(&self) -> usize {
24 match self {
25 KernelFormat::OIHW => 2,
26 KernelFormat::HWIO => 0,
27 KernelFormat::OHWI => 1,
28 }
29 }
30
31 pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
32 &full_shape[self.h_axis()..][..full_shape.len() - 2]
33 }
34
35 pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
36 self.spatial_shape(full_shape)
37 }
38
39 pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
40 match self {
41 KernelFormat::OIHW => &full_shape[1],
42 KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
43 KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
44 }
45 }
46
47 pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
48 match self {
49 KernelFormat::OIHW | KernelFormat::OHWI => 0,
50 KernelFormat::HWIO => full_shape.len() - 1,
51 }
52 }
53
54 pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
55 &full_shape[self.o_axis(full_shape)]
56 }
57
58 pub fn input_channels<'s, D: DimLike>(
59 &self,
60 full_kernel_shape: &'s [D],
61 group: usize,
62 ) -> Cow<'s, D> {
63 match self {
64 KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
65 KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
66 }
67 }
68
69 pub fn output_channels<'s, D: DimLike>(
70 &self,
71 full_kernel_shape: &'s [D],
72 group: usize,
73 ) -> Cow<'s, D> {
74 match self {
75 KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
76 KernelFormat::HWIO | KernelFormat::OHWI => {
77 Cow::Owned(self.o(full_kernel_shape).clone() * group)
78 }
79 }
80 }
81
82 pub fn kernel_as_group_o_i_h_w_ops(
83 &self,
84 full_shape: &[impl DimLike],
85 group: usize,
86 ) -> TVec<AxisOp> {
87 let geo_rank = full_shape.len() - 2;
88 match self {
89 KernelFormat::HWIO => {
91 tvec!(
92 AxisOp::Reshape(
93 geo_rank,
94 tvec!(self.i(full_shape).to_dim()),
95 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
96 ), AxisOp::Move(geo_rank, 0), AxisOp::Move(geo_rank + 2, 1), AxisOp::Move(geo_rank + 2, 2)
100 ) }
102 KernelFormat::OIHW => {
104 tvec!(AxisOp::Reshape(
105 0,
106 tvec!(self.o(full_shape).to_dim()),
107 tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
108 ))
109 }
110 KernelFormat::OHWI => {
112 tvec!(
113 AxisOp::Reshape(
114 geo_rank + 1,
115 tvec!(self.i(full_shape).to_dim()),
116 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
117 ), AxisOp::Move(geo_rank + 1, 0), AxisOp::Move(geo_rank + 2, 2)
120 )
121 }
122 }
123 }
124
125 pub fn kernel_as_group_o_i_hw_ops(
126 &self,
127 full_shape: &[impl DimLike],
128 group: usize,
129 ) -> TVec<AxisOp> {
130 let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
131 if self.hw(full_shape).len() > 1 {
132 ops.push(AxisOp::Reshape(
133 3,
134 self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
135 tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
136 ));
137 }
138 ops
139 }
140
141 pub fn kernel_as_group_o_ihw_ops(
142 &self,
143 full_shape: &[impl DimLike],
144 group: usize,
145 ) -> TVec<AxisOp> {
146 let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
147 let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
148 let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
149 ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
150 ops
151 }
152
153 pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
154 let mut kernel = kernel.clone();
155 let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
156 for op in &ops {
157 op.change_tensor(&mut kernel, false)?;
158 }
159 Ok(kernel)
160 }
161
162 pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
163 let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
164 Ok(group_o_i_hw.collapse_axis_with_next(2))
165 }
166}