1mod block_quant;
2mod blocked;
3#[allow(clippy::module_inception)]
4mod conv;
5mod depth_wise;
6mod im2col;
7mod lazy_im2col;
8mod q_sum_b;
9
10use crate::internal::*;
11use crate::ops::cnn::Deconv;
12
13pub use self::blocked::BlockedConv;
14pub use self::conv::Conv;
15pub use self::im2col::Im2Col;
16pub(crate) use self::q_sum_b::QSumB;
17
/// Axis layout of a convolution kernel tensor.
///
/// In the variant names, `O` is the output-channel axis, `I` the input-channel axis and `HW`
/// the run of spatial axes (one or more axes, see [`KernelFormat::spatial_shape`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes. This is the default format.
    #[default]
    OIHW,
    /// Spatial axes, then input channels, then output channels.
    HWIO,
    /// Output channels, spatial axes, then input channels.
    OHWI,
}
25
26impl KernelFormat {
27 pub fn h_axis(&self) -> usize {
28 match self {
29 KernelFormat::OIHW => 2,
30 KernelFormat::HWIO => 0,
31 KernelFormat::OHWI => 1,
32 }
33 }
34
35 pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
36 &full_shape[self.h_axis()..][..full_shape.len() - 2]
37 }
38
39 pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
40 self.spatial_shape(full_shape)
41 }
42
43 pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
44 match self {
45 KernelFormat::OIHW => &full_shape[1],
46 KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
47 KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
48 }
49 }
50
51 pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
52 match self {
53 KernelFormat::OIHW | KernelFormat::OHWI => 0,
54 KernelFormat::HWIO => full_shape.len() - 1,
55 }
56 }
57
58 pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
59 match self {
60 KernelFormat::OIHW => 1,
61 KernelFormat::OHWI => full_shape.len() - 1,
62 KernelFormat::HWIO => full_shape.len() - 2,
63 }
64 }
65
66 pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
67 &full_shape[self.o_axis(full_shape)]
68 }
69
70 pub fn input_channels<'s, D: DimLike>(
71 &self,
72 full_kernel_shape: &'s [D],
73 group: usize,
74 ) -> Cow<'s, D> {
75 match self {
76 KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
77 KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
78 }
79 }
80
81 pub fn output_channels<'s, D: DimLike>(
82 &self,
83 full_kernel_shape: &'s [D],
84 group: usize,
85 ) -> Cow<'s, D> {
86 match self {
87 KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
88 KernelFormat::HWIO | KernelFormat::OHWI => {
89 Cow::Owned(self.o(full_kernel_shape).clone() * group)
90 }
91 }
92 }
93
94 pub fn kernel_as_group_o_i_h_w_ops(
95 &self,
96 full_shape: &[impl DimLike],
97 group: usize,
98 ) -> TVec<AxisOp> {
99 let geo_rank = full_shape.len() - 2;
100 match self {
101 KernelFormat::HWIO => {
103 tvec!(
104 AxisOp::Reshape(
105 geo_rank,
106 tvec!(self.i(full_shape).to_dim()),
107 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
108 ), AxisOp::Move(geo_rank, 0), AxisOp::Move(geo_rank + 2, 1), AxisOp::Move(geo_rank + 2, 2)
112 ) }
114 KernelFormat::OIHW => {
116 tvec!(AxisOp::Reshape(
117 0,
118 tvec!(self.o(full_shape).to_dim()),
119 tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
120 ))
121 }
122 KernelFormat::OHWI => {
124 tvec!(
125 AxisOp::Reshape(
126 geo_rank + 1,
127 tvec!(self.i(full_shape).to_dim()),
128 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
129 ), AxisOp::Move(geo_rank + 1, 0), AxisOp::Move(geo_rank + 2, 2)
132 )
133 }
134 }
135 }
136
137 pub fn kernel_as_group_o_i_hw_ops(
138 &self,
139 full_shape: &[impl DimLike],
140 group: usize,
141 ) -> TVec<AxisOp> {
142 let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
143 if self.hw(full_shape).len() > 1 {
144 ops.push(AxisOp::Reshape(
145 3,
146 self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
147 tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
148 ));
149 }
150 ops
151 }
152
153 pub fn kernel_as_group_o_ihw_ops(
154 &self,
155 full_shape: &[impl DimLike],
156 group: usize,
157 ) -> TVec<AxisOp> {
158 let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
159 let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
160 let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
161 ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
162 ops
163 }
164
165 pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
166 let mut kernel = kernel.clone();
167 let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
168 for op in &ops {
169 op.change_tensor(&mut kernel, false)?;
170 }
171 Ok(kernel)
172 }
173
174 pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
175 let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
176 Ok(group_o_i_hw.collapse_axis_with_next(2))
177 }
178}
179
180pub fn rewrite_kernel_conv_in_oihw(
181 _ctx: &(),
182 model: &TypedModel,
183 node: &TypedNode,
184 name: &str,
185 conv: &Conv,
186) -> TractResult<Option<TypedModelPatch>> {
187 rewrite_kernel_in_oihw(
188 model,
189 node,
190 name,
191 conv.kernel_fmt,
192 conv.group,
193 Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
194 )
195}
196
197pub fn rewrite_kernel_deconv_in_oihw(
198 _ctx: &(),
199 model: &TypedModel,
200 node: &TypedNode,
201 name: &str,
202 conv: &Deconv,
203) -> TractResult<Option<TypedModelPatch>> {
204 rewrite_kernel_in_oihw(
205 model,
206 node,
207 name,
208 conv.kernel_format,
209 conv.group,
210 Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
211 )
212}
213
214fn rewrite_kernel_in_oihw(
215 model: &TypedModel,
216 node: &TypedNode,
217 name: &str,
218 fmt: KernelFormat,
219 group: usize,
220 new: Box<dyn TypedOp>,
221) -> TractResult<Option<TypedModelPatch>> {
222 rule_if!(fmt != KernelFormat::OIHW);
223 let mut patch = TypedModelPatch::default();
224 let mut wire = patch.taps(model, &node.inputs)?;
225 let prefix = format!("{name}.kernel_reorg");
226 for (ix, op) in fmt
227 .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
228 .into_iter()
229 .enumerate()
230 {
231 wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
232 }
233 wire[1] =
234 AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
235 wire = patch.wire_node(name, new, &wire)?;
236 patch.shunt_outside(model, node.id.into(), wire[0])?;
237 Ok(Some(patch))
238}