// tract_core/ops/cnn/conv/mod.rs

1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use tract_linalg::block_quant::BlockQuantFact;
10
11use crate::internal::*;
12
13pub use self::conv::Conv;
14pub use self::im2col::Im2Col;
15pub(crate) use self::q_sum_b::QSumB;
16
17fn block_quant_aware_weight_shape(weights: &TypedFact) -> TractResult<Cow<'_, ShapeFact>> {
18    if weights.datum_type.is_number() {
19        Ok(Cow::Borrowed(&weights.shape))
20    } else if let Some(bqf) =
21        weights.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
22    {
23        Ok(Cow::Owned(bqf.shape().into()))
24    } else {
25        todo!("Weights are expected to be numbers of BlockQuant");
26    }
27}
28
/// Axis layout of a convolution kernel tensor.
///
/// Each variant spells the kernel axes from first to last: `O` for output channels,
/// `I` for input channels, and `HW` for the spatial axes (one or more, e.g. H and W).
/// `OIHW` is the default layout.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// Output channels, then input channels, then the spatial axes.
    #[default]
    OIHW,
    /// Spatial axes first, then input channels, then output channels.
    HWIO,
    /// Output channels, then the spatial axes, then input channels.
    OHWI,
}
36
37impl KernelFormat {
38    pub fn h_axis(&self) -> usize {
39        match self {
40            KernelFormat::OIHW => 2,
41            KernelFormat::HWIO => 0,
42            KernelFormat::OHWI => 1,
43        }
44    }
45
46    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
47        &full_shape[self.h_axis()..][..full_shape.len() - 2]
48    }
49
50    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
51        self.spatial_shape(full_shape)
52    }
53
54    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
55        match self {
56            KernelFormat::OIHW => &full_shape[1],
57            KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
58            KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
59        }
60    }
61
62    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
63        match self {
64            KernelFormat::OIHW | KernelFormat::OHWI => 0,
65            KernelFormat::HWIO => full_shape.len() - 1,
66        }
67    }
68
69    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
70        &full_shape[self.o_axis(full_shape)]
71    }
72
73    pub fn input_channels<'s, D: DimLike>(
74        &self,
75        full_kernel_shape: &'s [D],
76        group: usize,
77    ) -> Cow<'s, D> {
78        match self {
79            KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
80            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
81        }
82    }
83
84    pub fn output_channels<'s, D: DimLike>(
85        &self,
86        full_kernel_shape: &'s [D],
87        group: usize,
88    ) -> Cow<'s, D> {
89        match self {
90            KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
91            KernelFormat::HWIO | KernelFormat::OHWI => {
92                Cow::Owned(self.o(full_kernel_shape).clone() * group)
93            }
94        }
95    }
96
97    pub fn kernel_as_group_o_i_h_w_ops(
98        &self,
99        full_shape: &[impl DimLike],
100        group: usize,
101    ) -> TVec<AxisOp> {
102        let geo_rank = full_shape.len() - 2;
103        match self {
104            // g is on i
105            KernelFormat::HWIO => {
106                tvec!(
107                    AxisOp::Reshape(
108                        geo_rank,
109                        tvec!(self.i(full_shape).to_dim()),
110                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
111                    ), // h w g i o
112                    AxisOp::Move(geo_rank, 0),     // g h w i o
113                    AxisOp::Move(geo_rank + 2, 1), // g o h w i
114                    AxisOp::Move(geo_rank + 2, 2)
115                ) // g o i h w
116            }
117            // g is on o
118            KernelFormat::OIHW => {
119                tvec!(AxisOp::Reshape(
120                    0,
121                    tvec!(self.o(full_shape).to_dim()),
122                    tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
123                ))
124            }
125            // g is on i
126            KernelFormat::OHWI => {
127                tvec!(
128                    AxisOp::Reshape(
129                        geo_rank + 1,
130                        tvec!(self.i(full_shape).to_dim()),
131                        tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
132                    ), // o h w g i
133                    AxisOp::Move(geo_rank + 1, 0), // g o h w i
134                    AxisOp::Move(geo_rank + 2, 2)
135                )
136            }
137        }
138    }
139
140    pub fn kernel_as_group_o_i_hw_ops(
141        &self,
142        full_shape: &[impl DimLike],
143        group: usize,
144    ) -> TVec<AxisOp> {
145        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
146        if self.hw(full_shape).len() > 1 {
147            ops.push(AxisOp::Reshape(
148                3,
149                self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
150                tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
151            ));
152        }
153        ops
154    }
155
156    pub fn kernel_as_group_o_ihw_ops(
157        &self,
158        full_shape: &[impl DimLike],
159        group: usize,
160    ) -> TVec<AxisOp> {
161        let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
162        let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
163        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
164        ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
165        ops
166    }
167
168    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
169        let mut kernel = kernel.clone();
170        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
171        for op in &ops {
172            op.change_tensor(&mut kernel, false)?;
173        }
174        Ok(kernel)
175    }
176
177    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
178        let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
179        Ok(group_o_i_hw.collapse_axis_with_next(2))
180    }
181}