Skip to main content

tract_core/ops/cnn/conv/
mod.rs

1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use tract_linalg::block_quant::BlockQuantFact;
10
11use crate::internal::*;
12use crate::ops::cnn::Deconv;
13
14pub use self::conv::Conv;
15pub use self::im2col::Im2Col;
16pub(crate) use self::q_sum_b::QSumB;
17
18fn block_quant_aware_weight_shape(weights: &TypedFact) -> TractResult<Cow<'_, ShapeFact>> {
19    if weights.datum_type.is_number() {
20        Ok(Cow::Borrowed(&weights.shape))
21    } else if let Some(bqf) =
22        weights.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
23    {
24        Ok(Cow::Owned(bqf.shape().into()))
25    } else {
26        todo!("Weights are expected to be numbers of BlockQuant");
27    }
28}
29
/// Axis layout of a convolution kernel tensor.
///
/// `O` is the output-channel axis and `I` the input-channel axis. `HW` stands for the
/// contiguous run of spatial axes, of which there may be any number.
#[derive(Debug, Default)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum KernelFormat {
    /// `[O, I, spatial...]` — the default layout.
    #[default]
    OIHW,
    /// `[spatial..., I, O]`.
    HWIO,
    /// `[O, spatial..., I]`.
    OHWI,
}
37
38impl KernelFormat {
39    pub fn h_axis(&self) -> usize {
40        match self {
41            KernelFormat::OIHW => 2,
42            KernelFormat::HWIO => 0,
43            KernelFormat::OHWI => 1,
44        }
45    }
46
47    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
48        &full_shape[self.h_axis()..][..full_shape.len() - 2]
49    }
50
51    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
52        self.spatial_shape(full_shape)
53    }
54
55    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
56        match self {
57            KernelFormat::OIHW => &full_shape[1],
58            KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
59            KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
60        }
61    }
62
63    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
64        match self {
65            KernelFormat::OIHW | KernelFormat::OHWI => 0,
66            KernelFormat::HWIO => full_shape.len() - 1,
67        }
68    }
69
70    pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
71        match self {
72            KernelFormat::OIHW => 1,
73            KernelFormat::OHWI => full_shape.len() - 1,
74            KernelFormat::HWIO => full_shape.len() - 2,
75        }
76    }
77
78    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
79        &full_shape[self.o_axis(full_shape)]
80    }
81
82    pub fn input_channels<'s, D: DimLike>(
83        &self,
84        full_kernel_shape: &'s [D],
85        group: usize,
86    ) -> Cow<'s, D> {
87        match self {
88            KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
89            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
90        }
91    }
92
93    pub fn output_channels<'s, D: DimLike>(
94        &self,
95        full_kernel_shape: &'s [D],
96        group: usize,
97    ) -> Cow<'s, D> {
98        match self {
99            KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
100            KernelFormat::HWIO | KernelFormat::OHWI => {
101                Cow::Owned(self.o(full_kernel_shape).clone() * group)
102            }
103        }
104    }
105
106    pub fn kernel_as_group_o_i_h_w_ops(
107        &self,
108        full_shape: &[impl DimLike],
109        group: usize,
110    ) -> TVec<AxisOp> {
111        let geo_rank = full_shape.len() - 2;
112        match self {
113            // g is on i
114            KernelFormat::HWIO => {
115                tvec!(
116                    AxisOp::Reshape(
117                        geo_rank,
118                        tvec!(self.i(full_shape).to_dim()),
119                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
120                    ), // h w g i o
121                    AxisOp::Move(geo_rank, 0),     // g h w i o
122                    AxisOp::Move(geo_rank + 2, 1), // g o h w i
123                    AxisOp::Move(geo_rank + 2, 2)
124                ) // g o i h w
125            }
126            // g is on o
127            KernelFormat::OIHW => {
128                tvec!(AxisOp::Reshape(
129                    0,
130                    tvec!(self.o(full_shape).to_dim()),
131                    tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
132                ))
133            }
134            // g is on i
135            KernelFormat::OHWI => {
136                tvec!(
137                    AxisOp::Reshape(
138                        geo_rank + 1,
139                        tvec!(self.i(full_shape).to_dim()),
140                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
141                    ), // o h w g i
142                    AxisOp::Move(geo_rank + 1, 0), // g o h w i
143                    AxisOp::Move(geo_rank + 2, 2)
144                )
145            }
146        }
147    }
148
149    pub fn kernel_as_group_o_i_hw_ops(
150        &self,
151        full_shape: &[impl DimLike],
152        group: usize,
153    ) -> TVec<AxisOp> {
154        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
155        if self.hw(full_shape).len() > 1 {
156            ops.push(AxisOp::Reshape(
157                3,
158                self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
159                tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
160            ));
161        }
162        ops
163    }
164
165    pub fn kernel_as_group_o_ihw_ops(
166        &self,
167        full_shape: &[impl DimLike],
168        group: usize,
169    ) -> TVec<AxisOp> {
170        let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
171        let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
172        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
173        ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
174        ops
175    }
176
177    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
178        let mut kernel = kernel.clone();
179        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
180        for op in &ops {
181            op.change_tensor(&mut kernel, false)?;
182        }
183        Ok(kernel)
184    }
185
186    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
187        let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
188        Ok(group_o_i_hw.collapse_axis_with_next(2))
189    }
190}
191
192pub fn rewrite_kernel_conv_in_oihw(
193    _ctx: &(),
194    model: &TypedModel,
195    node: &TypedNode,
196    name: &str,
197    conv: &Conv,
198) -> TractResult<Option<TypedModelPatch>> {
199    rewrite_kernel_in_oihw(
200        model,
201        node,
202        name,
203        conv.kernel_fmt,
204        conv.group,
205        Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
206    )
207}
208
209pub fn rewrite_kernel_deconv_in_oihw(
210    _ctx: &(),
211    model: &TypedModel,
212    node: &TypedNode,
213    name: &str,
214    conv: &Deconv,
215) -> TractResult<Option<TypedModelPatch>> {
216    rewrite_kernel_in_oihw(
217        model,
218        node,
219        name,
220        conv.kernel_format,
221        conv.group,
222        Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
223    )
224}
225
226fn rewrite_kernel_in_oihw(
227    model: &TypedModel,
228    node: &TypedNode,
229    name: &str,
230    fmt: KernelFormat,
231    group: usize,
232    new: Box<dyn TypedOp>,
233) -> TractResult<Option<TypedModelPatch>> {
234    rule_if!(fmt != KernelFormat::OIHW);
235    let mut patch = TypedModelPatch::default();
236    let mut wire = patch.taps(model, &node.inputs)?;
237    let prefix = format!("{name}.kernel_reorg");
238    for (ix, op) in fmt
239        .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
240        .into_iter()
241        .enumerate()
242    {
243        wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
244    }
245    wire[1] =
246        AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
247    wire = patch.wire_node(name, new, &wire)?;
248    patch.shunt_outside(model, node.id.into(), wire[0])?;
249    Ok(Some(patch))
250}