1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use crate::internal::*;
10use crate::ops::cnn::Deconv;
11
12pub use self::conv::Conv;
13pub use self::im2col::Im2Col;
14pub(crate) use self::q_sum_b::QSumB;
15
/// Axis ordering of a convolution kernel tensor.
///
/// Each letter names one axis (or axis group) of the kernel shape:
/// `O` is the output-channel axis, `I` the input-channel axis and `HW`
/// the run of spatial axes (there may be more or fewer than two of them).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes (the default).
    #[default]
    OIHW,
    /// Spatial axes, then input channels, then output channels.
    HWIO,
    /// Output channels, then the spatial axes, then input channels.
    OHWI,
}
23
24impl KernelFormat {
25 pub fn h_axis(&self) -> usize {
26 match self {
27 KernelFormat::OIHW => 2,
28 KernelFormat::HWIO => 0,
29 KernelFormat::OHWI => 1,
30 }
31 }
32
33 pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
34 &full_shape[self.h_axis()..][..full_shape.len() - 2]
35 }
36
37 pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
38 self.spatial_shape(full_shape)
39 }
40
41 pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
42 match self {
43 KernelFormat::OIHW => &full_shape[1],
44 KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
45 KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
46 }
47 }
48
49 pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
50 match self {
51 KernelFormat::OIHW | KernelFormat::OHWI => 0,
52 KernelFormat::HWIO => full_shape.len() - 1,
53 }
54 }
55
56 pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
57 match self {
58 KernelFormat::OIHW => 1,
59 KernelFormat::OHWI => full_shape.len() - 1,
60 KernelFormat::HWIO => full_shape.len() - 2,
61 }
62 }
63
64 pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
65 &full_shape[self.o_axis(full_shape)]
66 }
67
68 pub fn input_channels<'s, D: DimLike>(
69 &self,
70 full_kernel_shape: &'s [D],
71 group: usize,
72 ) -> Cow<'s, D> {
73 match self {
74 KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
75 KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
76 }
77 }
78
79 pub fn output_channels<'s, D: DimLike>(
80 &self,
81 full_kernel_shape: &'s [D],
82 group: usize,
83 ) -> Cow<'s, D> {
84 match self {
85 KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
86 KernelFormat::HWIO | KernelFormat::OHWI => {
87 Cow::Owned(self.o(full_kernel_shape).clone() * group)
88 }
89 }
90 }
91
92 pub fn kernel_as_group_o_i_h_w_ops(
93 &self,
94 full_shape: &[impl DimLike],
95 group: usize,
96 ) -> TVec<AxisOp> {
97 let geo_rank = full_shape.len() - 2;
98 match self {
99 KernelFormat::HWIO => {
101 tvec!(
102 AxisOp::Reshape(
103 geo_rank,
104 tvec!(self.i(full_shape).to_dim()),
105 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
106 ), AxisOp::Move(geo_rank, 0), AxisOp::Move(geo_rank + 2, 1), AxisOp::Move(geo_rank + 2, 2)
110 ) }
112 KernelFormat::OIHW => {
114 tvec!(AxisOp::Reshape(
115 0,
116 tvec!(self.o(full_shape).to_dim()),
117 tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
118 ))
119 }
120 KernelFormat::OHWI => {
122 tvec!(
123 AxisOp::Reshape(
124 geo_rank + 1,
125 tvec!(self.i(full_shape).to_dim()),
126 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
127 ), AxisOp::Move(geo_rank + 1, 0), AxisOp::Move(geo_rank + 2, 2)
130 )
131 }
132 }
133 }
134
135 pub fn kernel_as_group_o_i_hw_ops(
136 &self,
137 full_shape: &[impl DimLike],
138 group: usize,
139 ) -> TVec<AxisOp> {
140 let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
141 if self.hw(full_shape).len() > 1 {
142 ops.push(AxisOp::Reshape(
143 3,
144 self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
145 tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
146 ));
147 }
148 ops
149 }
150
151 pub fn kernel_as_group_o_ihw_ops(
152 &self,
153 full_shape: &[impl DimLike],
154 group: usize,
155 ) -> TVec<AxisOp> {
156 let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
157 let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
158 let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
159 ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
160 ops
161 }
162
163 pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
164 let mut kernel = kernel.clone();
165 let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
166 for op in &ops {
167 op.change_tensor(&mut kernel, false)?;
168 }
169 Ok(kernel)
170 }
171
172 pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
173 let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
174 Ok(group_o_i_hw.collapse_axis_with_next(2))
175 }
176}
177
178pub fn rewrite_kernel_conv_in_oihw(
179 _ctx: &(),
180 model: &TypedModel,
181 node: &TypedNode,
182 name: &str,
183 conv: &Conv,
184) -> TractResult<Option<TypedModelPatch>> {
185 rewrite_kernel_in_oihw(
186 model,
187 node,
188 name,
189 conv.kernel_fmt,
190 conv.group,
191 Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
192 )
193}
194
195pub fn rewrite_kernel_deconv_in_oihw(
196 _ctx: &(),
197 model: &TypedModel,
198 node: &TypedNode,
199 name: &str,
200 conv: &Deconv,
201) -> TractResult<Option<TypedModelPatch>> {
202 rewrite_kernel_in_oihw(
203 model,
204 node,
205 name,
206 conv.kernel_format,
207 conv.group,
208 Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
209 )
210}
211
212fn rewrite_kernel_in_oihw(
213 model: &TypedModel,
214 node: &TypedNode,
215 name: &str,
216 fmt: KernelFormat,
217 group: usize,
218 new: Box<dyn TypedOp>,
219) -> TractResult<Option<TypedModelPatch>> {
220 rule_if!(fmt != KernelFormat::OIHW);
221 let mut patch = TypedModelPatch::default();
222 let mut wire = patch.taps(model, &node.inputs)?;
223 let prefix = format!("{name}.kernel_reorg");
224 for (ix, op) in fmt
225 .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
226 .into_iter()
227 .enumerate()
228 {
229 wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
230 }
231 wire[1] =
232 AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
233 wire = patch.wire_node(name, new, &wire)?;
234 patch.shunt_outside(model, node.id.into(), wire[0])?;
235 Ok(Some(patch))
236}