1mod block_quant;
2#[allow(clippy::module_inception)]
3mod conv;
4mod depth_wise;
5mod im2col;
6mod lazy_im2col;
7mod q_sum_b;
8
9use tract_linalg::block_quant::BlockQuantFact;
10
11use crate::internal::*;
12
13pub use self::conv::Conv;
14pub use self::im2col::Im2Col;
15pub(crate) use self::q_sum_b::QSumB;
16
17fn block_quant_aware_weight_shape(weights: &TypedFact) -> TractResult<Cow<'_, ShapeFact>> {
18 if weights.datum_type.is_number() {
19 Ok(Cow::Borrowed(&weights.shape))
20 } else if let Some(bqf) =
21 weights.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
22 {
23 Ok(Cow::Owned(bqf.shape().into()))
24 } else {
25 todo!("Weights are expected to be numbers of BlockQuant");
26 }
27}
28
/// Axis ordering of a convolution kernel tensor.
///
/// Each variant name spells the order of the kernel axes: `O` is the output-channel
/// axis, `I` the input-channel axis, and `HW` stands for the run of spatial axes
/// (there may be more or fewer than two of them).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes. This is the default.
    #[default]
    OIHW,
    /// Spatial axes first, then input channels, then output channels.
    HWIO,
    /// Output channels, the spatial axes, then input channels.
    OHWI,
}
36
37impl KernelFormat {
38 pub fn h_axis(&self) -> usize {
39 match self {
40 KernelFormat::OIHW => 2,
41 KernelFormat::HWIO => 0,
42 KernelFormat::OHWI => 1,
43 }
44 }
45
46 pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
47 &full_shape[self.h_axis()..][..full_shape.len() - 2]
48 }
49
50 pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
51 self.spatial_shape(full_shape)
52 }
53
54 pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
55 match self {
56 KernelFormat::OIHW => &full_shape[1],
57 KernelFormat::HWIO => &full_shape[full_shape.len() - 2],
58 KernelFormat::OHWI => &full_shape[full_shape.len() - 1],
59 }
60 }
61
62 pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
63 match self {
64 KernelFormat::OIHW | KernelFormat::OHWI => 0,
65 KernelFormat::HWIO => full_shape.len() - 1,
66 }
67 }
68
69 pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
70 &full_shape[self.o_axis(full_shape)]
71 }
72
73 pub fn input_channels<'s, D: DimLike>(
74 &self,
75 full_kernel_shape: &'s [D],
76 group: usize,
77 ) -> Cow<'s, D> {
78 match self {
79 KernelFormat::OIHW => Cow::Owned(self.i(full_kernel_shape).clone() * group),
80 KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(self.i(full_kernel_shape)),
81 }
82 }
83
84 pub fn output_channels<'s, D: DimLike>(
85 &self,
86 full_kernel_shape: &'s [D],
87 group: usize,
88 ) -> Cow<'s, D> {
89 match self {
90 KernelFormat::OIHW => Cow::Borrowed(self.o(full_kernel_shape)),
91 KernelFormat::HWIO | KernelFormat::OHWI => {
92 Cow::Owned(self.o(full_kernel_shape).clone() * group)
93 }
94 }
95 }
96
97 pub fn kernel_as_group_o_i_h_w_ops(
98 &self,
99 full_shape: &[impl DimLike],
100 group: usize,
101 ) -> TVec<AxisOp> {
102 let geo_rank = full_shape.len() - 2;
103 match self {
104 KernelFormat::HWIO => {
106 tvec!(
107 AxisOp::Reshape(
108 geo_rank,
109 tvec!(self.i(full_shape).to_dim()),
110 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
111 ), AxisOp::Move(geo_rank, 0), AxisOp::Move(geo_rank + 2, 1), AxisOp::Move(geo_rank + 2, 2)
115 ) }
117 KernelFormat::OIHW => {
119 tvec!(AxisOp::Reshape(
120 0,
121 tvec!(self.o(full_shape).to_dim()),
122 tvec!(group.to_dim(), self.o(full_shape).to_dim() / group),
123 ))
124 }
125 KernelFormat::OHWI => {
127 tvec!(
128 AxisOp::Reshape(
129 geo_rank + 1,
130 tvec!(self.i(full_shape).to_dim()),
131 tvec!(group.to_dim(), self.i(full_shape).to_dim() / group),
132 ), AxisOp::Move(geo_rank + 1, 0), AxisOp::Move(geo_rank + 2, 2)
135 )
136 }
137 }
138 }
139
140 pub fn kernel_as_group_o_i_hw_ops(
141 &self,
142 full_shape: &[impl DimLike],
143 group: usize,
144 ) -> TVec<AxisOp> {
145 let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
146 if self.hw(full_shape).len() > 1 {
147 ops.push(AxisOp::Reshape(
148 3,
149 self.hw(full_shape).iter().map(|t| t.to_dim()).collect(),
150 tvec!(self.hw(full_shape).iter().map(|t| t.to_dim()).product()),
151 ));
152 }
153 ops
154 }
155
156 pub fn kernel_as_group_o_ihw_ops(
157 &self,
158 full_shape: &[impl DimLike],
159 group: usize,
160 ) -> TVec<AxisOp> {
161 let i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
162 let hw = self.hw(full_shape).iter().map(|t| t.to_dim()).product::<TDim>();
163 let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
164 ops.push(AxisOp::Reshape(2, tvec!(i.clone(), hw.clone()), tvec!(i * hw)));
165 ops
166 }
167
168 pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
169 let mut kernel = kernel.clone();
170 let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
171 for op in &ops {
172 op.change_tensor(&mut kernel, false)?;
173 }
174 Ok(kernel)
175 }
176
177 pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
178 let group_o_i_hw = self.kernel_as_group_o_i_hw(kernel, group)?;
179 Ok(group_o_i_hw.collapse_axis_with_next(2))
180 }
181}