Skip to main content

tract_core/ops/cnn/conv/
mod.rs

1mod block_quant;
2mod blocked;
3#[allow(clippy::module_inception)]
4mod conv;
5mod depth_wise;
6mod im2col;
7mod lazy_im2col;
8mod q_sum_b;
9
10use crate::internal::*;
11use crate::ops::cnn::Deconv;
12
13pub use self::blocked::BlockedConv;
14pub use self::conv::Conv;
15pub use self::im2col::Im2Col;
16pub(crate) use self::q_sum_b::QSumB;
17
/// Axis layout of a convolution kernel tensor.
///
/// Each letter names an axis: `O` for output channels, `I` for input channels,
/// and `HW` for the run of spatial axes (all axes but two of the kernel).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// Output channels first, then input channels, then the spatial axes.
    #[default]
    OIHW,
    /// Spatial axes first, then input channels, output channels last.
    HWIO,
    /// Output channels first, then the spatial axes, input channels last.
    OHWI,
}
25
26impl KernelFormat {
27    pub fn h_axis(&self) -> usize {
28        match self {
29            KernelFormat::OIHW => 2,
30            KernelFormat::HWIO => 0,
31            KernelFormat::OHWI => 1,
32        }
33    }
34
35    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
36        &full_shape[self.h_axis()..][..full_shape.len() - 2]
37    }
38
39    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
40        self.spatial_shape(full_shape)
41    }
42
43    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
44        match self {
45            KernelFormat::OIHW => &full_shape[1],
46            KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
47            KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
48        }
49    }
50
51    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
52        match self {
53            KernelFormat::OIHW | KernelFormat::OHWI => 0,
54            KernelFormat::HWIO => full_shape.len() - 1,
55        }
56    }
57
58    pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
59        match self {
60            KernelFormat::OIHW => 1,
61            KernelFormat::OHWI => full_shape.len() - 1,
62            KernelFormat::HWIO => full_shape.len() - 2,
63        }
64    }
65
66    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
67        &full_shape[self.o_axis(full_shape)]
68    }
69
70    pub fn input_channels<'s, D: DimLike>(
71        &self,
72        full_kernel_shape: &'s [D],
73        group: usize,
74    ) -> Cow<'s, D> {
75        match self {
76            KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
77            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
78        }
79    }
80
81    pub fn output_channels<'s, D: DimLike>(
82        &self,
83        full_kernel_shape: &'s [D],
84        group: usize,
85    ) -> Cow<'s, D> {
86        match self {
87            KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
88            KernelFormat::HWIO | KernelFormat::OHWI => {
89                Cow::Owned(self.o(full_kernel_shape).clone() * group)
90            }
91        }
92    }
93
94    pub fn kernel_as_group_o_i_h_w_ops(
95        &self,
96        full_shape: &[impl DimLike],
97        group: usize,
98    ) -> TVec<AxisOp> {
99        let geo_rank = full_shape.len() - 2;
100        match self {
101            // g is on i
102            KernelFormat::HWIO => {
103                tvec!(
104                    AxisOp::Reshape(
105                        geo_rank,
106                        tvec!(self.i(full_shape).to_dim()),
107                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
108                    ), // h w g i o
109                    AxisOp::Move(geo_rank, 0),     // g h w i o
110                    AxisOp::Move(geo_rank + 2, 1), // g o h w i
111                    AxisOp::Move(geo_rank + 2, 2)
112                ) // g o i h w
113            }
114            // g is on o
115            KernelFormat::OIHW => {
116                tvec!(AxisOp::Reshape(
117                    0,
118                    tvec!(self.o(full_shape).to_dim()),
119                    tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
120                ))
121            }
122            // g is on i
123            KernelFormat::OHWI => {
124                tvec!(
125                    AxisOp::Reshape(
126                        geo_rank + 1,
127                        tvec!(self.i(full_shape).to_dim()),
128                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
129                    ), // o h w g i
130                    AxisOp::Move(geo_rank + 1, 0), // g o h w i
131                    AxisOp::Move(geo_rank + 2, 2)
132                )
133            }
134        }
135    }
136
137    pub fn kernel_as_group_o_i_hw_ops(
138        &self,
139        full_shape: &[impl DimLike],
140        group: usize,
141    ) -> TVec<AxisOp> {
142        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
143        if self.hw(full_shape).len() > 1 {
144            ops.push(AxisOp::Reshape(
145                3,
146                self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
147                tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
148            ));
149        }
150        ops
151    }
152
153    pub fn kernel_as_group_o_ihw_ops(
154        &self,
155        full_shape: &[impl DimLike],
156        group: usize,
157    ) -> TVec<AxisOp> {
158        let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
159        let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
160        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
161        ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
162        ops
163    }
164
165    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
166        let mut kernel = kernel.clone();
167        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
168        for op in &ops {
169            op.change_tensor(&mut kernel, false)?;
170        }
171        Ok(kernel)
172    }
173
174    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
175        let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
176        Ok(group_o_i_hw.collapse_axis_with_next(2))
177    }
178}
179
180pub fn rewrite_kernel_conv_in_oihw(
181    _ctx: &(),
182    model: &TypedModel,
183    node: &TypedNode,
184    name: &str,
185    conv: &Conv,
186) -> TractResult<Option<TypedModelPatch>> {
187    rewrite_kernel_in_oihw(
188        model,
189        node,
190        name,
191        conv.kernel_fmt,
192        conv.group,
193        Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
194    )
195}
196
197pub fn rewrite_kernel_deconv_in_oihw(
198    _ctx: &(),
199    model: &TypedModel,
200    node: &TypedNode,
201    name: &str,
202    conv: &Deconv,
203) -> TractResult<Option<TypedModelPatch>> {
204    rewrite_kernel_in_oihw(
205        model,
206        node,
207        name,
208        conv.kernel_format,
209        conv.group,
210        Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
211    )
212}
213
214fn rewrite_kernel_in_oihw(
215    model: &TypedModel,
216    node: &TypedNode,
217    name: &str,
218    fmt: KernelFormat,
219    group: usize,
220    new: Box<dyn TypedOp>,
221) -> TractResult<Option<TypedModelPatch>> {
222    rule_if!(fmt != KernelFormat::OIHW);
223    let mut patch = TypedModelPatch::default();
224    let mut wire = patch.taps(model, &node.inputs)?;
225    let prefix = format!("{name}.kernel_reorg");
226    for (ix, op) in fmt
227        .kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(wire[1])?.shape, group)
228        .into_iter()
229        .enumerate()
230    {
231        wire[1] = patch.wire_node(format!("{prefix}.{ix}"), op, &[wire[1]])?[0];
232    }
233    wire[1] =
234        AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), wire[1], 0)?[0];
235    wire = patch.wire_node(name, new, &wire)?;
236    patch.shunt_outside(model, node.id.into(), wire[0])?;
237    Ok(Some(patch))
238}