1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use tract_linalg::block_quant::BlockQuantFact;
10
11use crate::internal::*;
12use crate::ops::cnn::Deconv;
13
14pub use self::conv::Conv;
15pub use self::im2col::Im2Col;
16pub(crate) use self::q_sum_b::QSumB;
17
18fn block_quant_aware_weight_shape(weights: &TypedFact) -> TractResult<Cow<'_, ShapeFact>> {
19 if weights.datum_type.is_number() {
20 Ok(Cow::Borrowed(&weights.shape))
21 } else if let Some(bqf) =
22 weights.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
23 {
24 Ok(Cow::Owned(bqf.shape().into()))
25 } else {
26 todo!("Weights are expected to be numbers of BlockQuant");
27 }
28}
29
/// Axis ordering of a convolution kernel tensor.
///
/// Each letter names one axis group: `O` for output channels, `I` for input
/// channels and `HW` for the run of spatial axes. `OIHW` is the default layout.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// Output channels, input channels, then spatial axes.
    #[default]
    OIHW,
    /// Spatial axes, input channels, then output channels.
    HWIO,
    /// Output channels, spatial axes, then input channels.
    OHWI,
}
37
38impl KernelFormat {
39 pub fn h_axis(&self) -> usize {
40 match self {
41 KernelFormat::OIHW => 2,
42 KernelFormat::HWIO => 0,
43 KernelFormat::OHWI => 1,
44 }
45 }
46
47 pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
48 &full_shape[self.h_axis()..][..full_shape.len() - 2]
49 }
50
51 pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
52 self.spatial_shape(full_shape)
53 }
54
55 pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
56 match self {
57 KernelFormat::OIHW => &full_shape[1],
58 KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
59 KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
60 }
61 }
62
63 pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
64 match self {
65 KernelFormat::OIHW | KernelFormat::OHWI => 0,
66 KernelFormat::HWIO => full_shape.len() - 1,
67 }
68 }
69
70 pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
71 match self {
72 KernelFormat::OIHW => 1,
73 KernelFormat::OHWI => full_shape.len() - 1,
74 KernelFormat::HWIO => full_shape.len() - 2,
75 }
76 }
77
78 pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
79 &full_shape[self.o_axis(full_shape)]
80 }
81
82 pub fn input_channels<'s, D: DimLike>(
83 &self,
84 full_kernel_shape: &'s [D],
85 group: usize,
86 ) -> Cow<'s, D> {
87 match self {
88 KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
89 KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
90 }
91 }
92
93 pub fn output_channels<'s, D: DimLike>(
94 &self,
95 full_kernel_shape: &'s [D],
96 group: usize,
97 ) -> Cow<'s, D> {
98 match self {
99 KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
100 KernelFormat::HWIO | KernelFormat::OHWI => {
101 Cow::Owned(self.o(full_kernel_shape).clone() * group)
102 }
103 }
104 }
105
106 pub fn kernel_as_group_o_i_h_w_ops(
107 &self,
108 full_shape: &[impl DimLike],
109 group: usize,
110 ) -> TVec<AxisOp> {
111 let geo_rank = full_shape.len() - 2;
112 match self {
113 KernelFormat::HWIO => {
115 tvec!(
116 AxisOp::Reshape(
117 geo_rank,
118 tvec!(self.i(full_shape).to_dim()),
119 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
120 ), AxisOp::Move(geo_rank, 0), AxisOp::Move(geo_rank + 2, 1), AxisOp::Move(geo_rank + 2, 2)
124 ) }
126 KernelFormat::OIHW => {
128 tvec!(AxisOp::Reshape(
129 0,
130 tvec!(self.o(full_shape).to_dim()),
131 tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
132 ))
133 }
134 KernelFormat::OHWI => {
136 tvec!(
137 AxisOp::Reshape(
138 geo_rank + 1,
139 tvec!(self.i(full_shape).to_dim()),
140 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
141 ), AxisOp::Move(geo_rank + 1, 0), AxisOp::Move(geo_rank + 2, 2)
144 )
145 }
146 }
147 }
148
149 pub fn kernel_as_group_o_i_hw_ops(
150 &self,
151 full_shape: &[impl DimLike],
152 group: usize,
153 ) -> TVec<AxisOp> {
154 let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
155 if self.hw(full_shape).len() > 1 {
156 ops.push(AxisOp::Reshape(
157 3,
158 self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
159 tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
160 ));
161 }
162 ops
163 }
164
165 pub fn kernel_as_group_o_ihw_ops(
166 &self,
167 full_shape: &[impl DimLike],
168 group: usize,
169 ) -> TVec<AxisOp> {
170 let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
171 let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
172 let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
173 ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
174 ops
175 }
176
177 pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
178 let mut kernel = kernel.clone();
179 let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
180 for op in &ops {
181 op.change_tensor(&mut kernel, false)?;
182 }
183 Ok(kernel)
184 }
185
186 pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
187 let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
188 Ok(group_o_i_hw.collapse_axis_with_next(2))
189 }
190}
191
192pub fn rewrite_kernel_conv_in_oihw(
193 _ctx: &(),
194 model: &TypedModel,
195 node: &TypedNode,
196 name: &str,
197 conv: &Conv,
198) -> TractResult<Option<TypedModelPatch>> {
199 rewrite_kernel_in_oihw(
200 model,
201 node,
202 name,
203 conv.kernel_fmt,
204 conv.group,
205 Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
206 )
207}
208
209pub fn rewrite_kernel_deconv_in_oihw(
210 _ctx: &(),
211 model: &TypedModel,
212 node: &TypedNode,
213 name: &str,
214 conv: &Deconv,
215) -> TractResult<Option<TypedModelPatch>> {
216 rewrite_kernel_in_oihw(
217 model,
218 node,
219 name,
220 conv.kernel_format,
221 conv.group,
222 Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
223 )
224}
225
226fn rewrite_kernel_in_oihw(
227 model: &TypedModel,
228 node: &TypedNode,
229 name: &str,
230 fmt: KernelFormat,
231 group: usize,
232 new: Box<dyn TypedOp>,
233) -> TractResult<Option<TypedModelPatch>> {
234 rule_if!(fmt != KernelFormat::OIHW);
235 let mut patch = TypedModelPatch::default();
236 let mut wire = patch.taps(model, &node.inputs)?;
237 let prefix = format!("{name}.kernel_reorg");
238 for (ix, op) in fmt
239 .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
240 .into_iter()
241 .enumerate()
242 {
243 wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
244 }
245 wire[1] =
246 AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
247 wire = patch.wire_node(name, new, &wire)?;
248 patch.shunt_outside(model, node.id.into(), wire[0])?;
249 Ok(Some(patch))
250}