1use crate::{CnnError, CnnResult, Tensor};
9
10use super::{Conv2d, Layer};
11
/// Depthwise 2-D convolution with per-channel symmetric int8 weight
/// quantization.
///
/// Weights are stored flat as `[channels, kernel_size, kernel_size]`; each
/// channel carries its own scale so one outlier channel cannot waste the
/// int8 range of the others.
#[derive(Debug, Clone)]
pub struct QuantizedDepthwiseConv2d {
    // Quantized weights, flat [channels * kernel_size * kernel_size].
    weights_q: Vec<i8>,

    // Per-channel dequantization scale: w_f32 ≈ (w_q as f32) * weight_scales[c].
    weight_scales: Vec<f32>,

    // Bias in the int32 accumulator domain (unit = input_scale * weight_scales[c]).
    bias_q: Vec<i32>,

    // Original FP32 bias — presumably kept for re-quantization against a new
    // input scale or for debugging; not read in this file. TODO confirm callers.
    bias_f32: Vec<f32>,

    // Channel count (depthwise: exactly one filter per input channel).
    channels: usize,
    // Square kernel side length.
    kernel_size: usize,
    // Spatial stride, identical for height and width.
    stride: usize,
    // Zero padding applied to each spatial border.
    padding: usize,
}
38
39impl QuantizedDepthwiseConv2d {
40 pub fn from_fp32(
51 channels: usize,
52 kernel_size: usize,
53 weights: &[f32],
54 bias: Option<&[f32]>,
55 stride: usize,
56 padding: usize,
57 input_scale: f32,
58 ) -> Self {
59 let mut weight_scales = vec![0.0f32; channels];
61
62 for c in 0..channels {
63 let mut max_abs = 0.0f32;
64 for kh in 0..kernel_size {
65 for kw in 0..kernel_size {
66 let idx = c * kernel_size * kernel_size + kh * kernel_size + kw;
67 max_abs = max_abs.max(weights[idx].abs());
68 }
69 }
70 weight_scales[c] = if max_abs > 0.0 {
71 max_abs / 127.0
72 } else {
73 1.0
74 };
75 }
76
77 let mut weights_q = vec![0i8; weights.len()];
79 for c in 0..channels {
80 let scale = weight_scales[c];
81 for kh in 0..kernel_size {
82 for kw in 0..kernel_size {
83 let idx = c * kernel_size * kernel_size + kh * kernel_size + kw;
84 let w_q = (weights[idx] / scale).round().clamp(-127.0, 127.0) as i8;
85 weights_q[idx] = w_q;
86 }
87 }
88 }
89
90 let bias_f32 = bias.map(|b| b.to_vec()).unwrap_or_else(|| vec![0.0; channels]);
92 let mut bias_q = vec![0i32; channels];
93
94 for c in 0..channels {
95 let combined_scale = input_scale * weight_scales[c];
96 bias_q[c] = if combined_scale > 0.0 {
97 (bias_f32[c] / combined_scale).round() as i32
98 } else {
99 0
100 };
101 }
102
103 Self {
104 weights_q,
105 weight_scales,
106 bias_q,
107 bias_f32,
108 channels,
109 kernel_size,
110 stride,
111 padding,
112 }
113 }
114
115 pub fn forward_int8(
123 &self,
124 input: &[u8],
125 input_shape: &[usize],
126 input_scale: f32,
127 input_zero_point: u8,
128 ) -> CnnResult<Tensor> {
129 if input_shape.len() != 4 {
130 return Err(CnnError::invalid_shape(
131 "4D input (NHWC)",
132 format!("{}D", input_shape.len())
133 ));
134 }
135
136 let batch = input_shape[0];
137 let in_h = input_shape[1];
138 let in_w = input_shape[2];
139 let in_c = input_shape[3];
140
141 if in_c != self.channels {
142 return Err(CnnError::invalid_shape(
143 format!("{} channels", self.channels),
144 format!("{} channels", in_c)
145 ));
146 }
147
148 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
149 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
150
151 let mut output_i32 = vec![0i32; batch * out_h * out_w * self.channels];
152
153 for b in 0..batch {
155 let batch_in_size = in_h * in_w * in_c;
156 let batch_out_size = out_h * out_w * self.channels;
157
158 let input_slice = &input[b * batch_in_size..(b + 1) * batch_in_size];
159 let output_slice = &mut output_i32[b * batch_out_size..(b + 1) * batch_out_size];
160
161 self.depthwise_conv_int8_scalar(
162 input_slice,
163 input_zero_point as i32,
164 output_slice,
165 in_h, in_w, out_h, out_w,
166 );
167 }
168
169 let output_f32 = self.dequantize_output(&output_i32, input_scale);
171
172 Tensor::from_data(
173 output_f32,
174 &[batch, out_h, out_w, self.channels],
175 )
176 }
177
178 fn depthwise_conv_int8_scalar(
180 &self,
181 input: &[u8],
182 input_zero_point: i32,
183 output: &mut [i32],
184 in_h: usize,
185 in_w: usize,
186 out_h: usize,
187 out_w: usize,
188 ) {
189 let ks = self.kernel_size;
190
191 let mut weight_sums = vec![0i32; self.channels];
193 for c in 0..self.channels {
194 let mut sum = 0i32;
195 for kh in 0..ks {
196 for kw in 0..ks {
197 let idx = c * ks * ks + kh * ks + kw;
198 sum += self.weights_q[idx] as i32;
199 }
200 }
201 weight_sums[c] = sum;
202 }
203
204 for oh in 0..out_h {
206 for ow in 0..out_w {
207 for c in 0..self.channels {
208 let mut acc = self.bias_q[c] - input_zero_point * weight_sums[c];
210
211 for kh in 0..ks {
213 for kw in 0..ks {
214 let ih = (oh * self.stride + kh) as isize - self.padding as isize;
215 let iw = (ow * self.stride + kw) as isize - self.padding as isize;
216
217 if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
218 let ih = ih as usize;
219 let iw = iw as usize;
220
221 let input_idx = (ih * in_w + iw) * self.channels + c;
222 let weight_idx = c * ks * ks + kh * ks + kw;
223
224 acc += (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32);
225 }
226 }
227 }
228
229 output[(oh * out_w + ow) * self.channels + c] = acc;
230 }
231 }
232 }
233 }
234
235 fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
237 let mut output = vec![0.0f32; acc.len()];
238
239 for (i, &val) in acc.iter().enumerate() {
240 let c = i % self.channels;
241 let scale = input_scale * self.weight_scales[c];
242 output[i] = val as f32 * scale;
243 }
244
245 output
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantizing a uniform-weight layer should preserve the layer geometry.
    #[test]
    fn test_quantized_depthwise_conv2d_creation() {
        const CHANNELS: usize = 32;
        const KERNEL: usize = 3;

        let weights = vec![0.1f32; CHANNELS * KERNEL * KERNEL];
        let bias = vec![0.0f32; CHANNELS];

        let layer = QuantizedDepthwiseConv2d::from_fp32(
            CHANNELS,
            KERNEL,
            &weights,
            Some(&bias),
            1,
            1,
            0.01,
        );

        assert_eq!(layer.channels, 32);
        assert_eq!(layer.kernel_size, 3);
    }

    /// stride=1 / padding=1 with a 3x3 kernel is shape-preserving: the
    /// output tensor must keep the input's spatial dimensions.
    #[test]
    fn test_quantized_depthwise_conv2d_forward() {
        const CHANNELS: usize = 16;
        const KERNEL: usize = 3;

        let weights = vec![0.1f32; CHANNELS * KERNEL * KERNEL];
        let layer = QuantizedDepthwiseConv2d::from_fp32(
            CHANNELS, KERNEL, &weights, None, 1, 1, 0.01,
        );

        // All-zero-point input (128 == zp) for a single 8x8 image.
        let input = vec![128u8; 8 * 8 * CHANNELS];
        let result = layer
            .forward_int8(&input, &[1, 8, 8, CHANNELS], 0.01, 128)
            .unwrap();

        assert_eq!(result.shape(), &[1, 8, 8, CHANNELS]);
    }
}