Skip to main content

tract_core/ops/cnn/conv/
mod.rs

1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use crate::internal::*;
10use crate::ops::cnn::Deconv;
11
12pub use self::conv::Conv;
13pub use self::im2col::Im2Col;
14pub(crate) use self::q_sum_b::QSumB;
15
/// Axis layout of a convolution kernel tensor.
///
/// Each letter names one axis, in storage order: `O` output channels,
/// `I` input channels, `H`/`W` the spatial axes. The spatial part may hold
/// any number of axes (everything that is not `O` or `I`).
#[derive(Default, Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes (the default).
    #[default]
    OIHW,
    /// Spatial axes, input channels, then output channels.
    HWIO,
    /// Output channels, spatial axes, then input channels.
    OHWI,
}
23
24impl KernelFormat {
25    pub fn h_axis(&self) -> usize {
26        match self {
27            KernelFormat::OIHW => 2,
28            KernelFormat::HWIO => 0,
29            KernelFormat::OHWI => 1,
30        }
31    }
32
33    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
34        &full_shape[self.h_axis()..][..full_shape.len() - 2]
35    }
36
37    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
38        self.spatial_shape(full_shape)
39    }
40
41    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
42        match self {
43            KernelFormat::OIHW => &full_shape[1],
44            KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
45            KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
46        }
47    }
48
49    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
50        match self {
51            KernelFormat::OIHW | KernelFormat::OHWI => 0,
52            KernelFormat::HWIO => full_shape.len() - 1,
53        }
54    }
55
56    pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
57        match self {
58            KernelFormat::OIHW => 1,
59            KernelFormat::OHWI => full_shape.len() - 1,
60            KernelFormat::HWIO => full_shape.len() - 2,
61        }
62    }
63
64    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
65        &full_shape[self.o_axis(full_shape)]
66    }
67
68    pub fn input_channels<'s, D: DimLike>(
69        &self,
70        full_kernel_shape: &'s [D],
71        group: usize,
72    ) -> Cow<'s, D> {
73        match self {
74            KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
75            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
76        }
77    }
78
79    pub fn output_channels<'s, D: DimLike>(
80        &self,
81        full_kernel_shape: &'s [D],
82        group: usize,
83    ) -> Cow<'s, D> {
84        match self {
85            KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
86            KernelFormat::HWIO | KernelFormat::OHWI => {
87                Cow::Owned(self.o(full_kernel_shape).clone() * group)
88            }
89        }
90    }
91
92    pub fn kernel_as_group_o_i_h_w_ops(
93        &self,
94        full_shape: &[impl DimLike],
95        group: usize,
96    ) -> TVec<AxisOp> {
97        let geo_rank = full_shape.len() - 2;
98        match self {
99            // g is on i
100            KernelFormat::HWIO => {
101                tvec!(
102                    AxisOp::Reshape(
103                        geo_rank,
104                        tvec!(self.i(full_shape).to_dim()),
105                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
106                    ), // h w g i o
107                    AxisOp::Move(geo_rank, 0),     // g h w i o
108                    AxisOp::Move(geo_rank + 2, 1), // g o h w i
109                    AxisOp::Move(geo_rank + 2, 2)
110                ) // g o i h w
111            }
112            // g is on o
113            KernelFormat::OIHW => {
114                tvec!(AxisOp::Reshape(
115                    0,
116                    tvec!(self.o(full_shape).to_dim()),
117                    tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
118                ))
119            }
120            // g is on i
121            KernelFormat::OHWI => {
122                tvec!(
123                    AxisOp::Reshape(
124                        geo_rank + 1,
125                        tvec!(self.i(full_shape).to_dim()),
126                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
127                    ), // o h w g i
128                    AxisOp::Move(geo_rank + 1, 0), // g o h w i
129                    AxisOp::Move(geo_rank + 2, 2)
130                )
131            }
132        }
133    }
134
135    pub fn kernel_as_group_o_i_hw_ops(
136        &self,
137        full_shape: &[impl DimLike],
138        group: usize,
139    ) -> TVec<AxisOp> {
140        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
141        if self.hw(full_shape).len() > 1 {
142            ops.push(AxisOp::Reshape(
143                3,
144                self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
145                tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
146            ));
147        }
148        ops
149    }
150
151    pub fn kernel_as_group_o_ihw_ops(
152        &self,
153        full_shape: &[impl DimLike],
154        group: usize,
155    ) -> TVec<AxisOp> {
156        let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
157        let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
158        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
159        ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
160        ops
161    }
162
163    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
164        let mut kernel = kernel.clone();
165        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
166        for op in &ops {
167            op.change_tensor(&mut kernel, false)?;
168        }
169        Ok(kernel)
170    }
171
172    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
173        let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
174        Ok(group_o_i_hw.collapse_axis_with_next(2))
175    }
176}
177
178pub fn rewrite_kernel_conv_in_oihw(
179    _ctx: &(),
180    model: &TypedModel,
181    node: &TypedNode,
182    name: &str,
183    conv: &Conv,
184) -> TractResult<Option<TypedModelPatch>> {
185    rewrite_kernel_in_oihw(
186        model,
187        node,
188        name,
189        conv.kernel_fmt,
190        conv.group,
191        Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
192    )
193}
194
195pub fn rewrite_kernel_deconv_in_oihw(
196    _ctx: &(),
197    model: &TypedModel,
198    node: &TypedNode,
199    name: &str,
200    conv: &Deconv,
201) -> TractResult<Option<TypedModelPatch>> {
202    rewrite_kernel_in_oihw(
203        model,
204        node,
205        name,
206        conv.kernel_format,
207        conv.group,
208        Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
209    )
210}
211
212fn rewrite_kernel_in_oihw(
213    model: &TypedModel,
214    node: &TypedNode,
215    name: &str,
216    fmt: KernelFormat,
217    group: usize,
218    new: Box<dyn TypedOp>,
219) -> TractResult<Option<TypedModelPatch>> {
220    rule_if!(fmt != KernelFormat::OIHW);
221    let mut patch = TypedModelPatch::default();
222    let mut wire = patch.taps(model, &node.inputs)?;
223    let prefix = format!("{name}.kernel_reorg");
224    for (ix, op) in fmt
225        .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
226        .into_iter()
227        .enumerate()
228    {
229        wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
230    }
231    wire[1] =
232        AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
233    wire = patch.wire_node(name, new, &wire)?;
234    patch.shunt_outside(model, node.id.into(), wire[0])?;
235    Ok(Some(patch))
236}