use crate::{CnnError, CnnResult, Tensor};

use super::Conv2d;

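/// A 2D convolution layer prepared for integer inference: weights are stored
/// as symmetric per-output-channel INT8, activations are expected as `u8`
/// with an asymmetric zero point, and accumulation is carried out in `i32`.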
#[derive(Debug, Clone)]
pub struct QuantizedConv2d {
    /// Quantized weights in OHWI layout:
    /// `idx = ((oc * ks + kh) * ks + kw) * in_channels + ic`.
    weights_q: Vec<i8>,

    /// One symmetric scale per output channel: `max|w| / 127`.
    weight_scales: Vec<f32>,

    /// Biases quantized with the combined scale `input_scale * weight_scales[oc]`.
    bias_q: Vec<i32>,

    /// Original fp32 biases, kept alongside the quantized copy.
    bias_f32: Vec<f32>,

    // Convolution geometry (square kernel). `groups` is stored but the
    // kernels below currently assume `groups == 1`.
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    groups: usize,
}

impl QuantizedConv2d {
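    /// Builds a quantized layer from an fp32 `Conv2d`.
    ///
    /// Weights use symmetric per-output-channel quantization:
    /// `scale_oc = max|w_oc| / 127` and `w_q = round(w / scale_oc)`, clamped
    /// to `[-127, 127]`. Biases are quantized with the combined scale
    /// `input_scale * scale_oc` and zero point 0. The input zero point is not
    /// needed here (it is applied at inference time), so the parameter is
    /// unused.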
    pub fn from_fp32(
        conv: &Conv2d,
        input_scale: f32,
        _input_zero_point: i32,
    ) -> Self {
        let out_c = conv.out_channels();
        let in_c = conv.in_channels();
        let ks = conv.kernel_size();

        // The per-output-channel max-abs value determines each channel's scale.
        let mut weight_scales = vec![0.0f32; out_c];
        let weights = conv.weights();

        for oc in 0..out_c {
            let mut max_abs = 0.0f32;
            for ic in 0..in_c {
                for kh in 0..ks {
                    for kw in 0..ks {
                        let idx = ((oc * ks + kh) * ks + kw) * in_c + ic;
                        max_abs = max_abs.max(weights[idx].abs());
                    }
                }
            }
            weight_scales[oc] = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
        }

        // Quantize weights with round-to-nearest, clamped to the symmetric
        // int8 range [-127, 127].
        let mut weights_q = vec![0i8; weights.len()];
        for oc in 0..out_c {
            let scale = weight_scales[oc];
            for ic in 0..in_c {
                for kh in 0..ks {
                    for kw in 0..ks {
                        let idx = ((oc * ks + kh) * ks + kw) * in_c + ic;
                        let w_f32 = weights[idx];
                        weights_q[idx] = (w_f32 / scale).round().clamp(-127.0, 127.0) as i8;
                    }
                }
            }
        }

        // Biases live in the accumulator domain, whose scale is
        // input_scale * weight_scale.
        let bias_f32 = conv
            .bias()
            .map(|b| b.to_vec())
            .unwrap_or_else(|| vec![0.0; out_c]);
        let mut bias_q = vec![0i32; out_c];

        for oc in 0..out_c {
            let combined_scale = input_scale * weight_scales[oc];
            bias_q[oc] = if combined_scale > 0.0 {
                (bias_f32[oc] / combined_scale).round() as i32
            } else {
                0
            };
        }

        Self {
            weights_q,
            weight_scales,
            bias_q,
            bias_f32,
            in_channels: in_c,
            out_channels: out_c,
            kernel_size: ks,
            stride: conv.stride(),
            padding: conv.padding(),
            groups: conv.groups(),
        }
    }

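    /// Runs the convolution on quantized `u8` input in NHWC layout and
    /// returns a dequantized fp32 tensor.
    ///
    /// Input values are interpreted as
    /// `real = (q - input_zero_point) * input_scale`; the output spatial size
    /// follows the usual `(in + 2 * padding - kernel_size) / stride + 1`.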
    pub fn forward_int8(
        &self,
        input: &[u8],
        input_shape: &[usize],
        input_scale: f32,
        input_zero_point: u8,
    ) -> CnnResult<Tensor> {
        if input_shape.len() != 4 {
            return Err(CnnError::invalid_shape(
                "4D input (NHWC)",
                format!("{}D", input_shape.len()),
            ));
        }

        let batch = input_shape[0];
        let in_h = input_shape[1];
        let in_w = input_shape[2];
        let in_c = input_shape[3];

        if in_c != self.in_channels {
            return Err(CnnError::invalid_shape(
                format!("{} input channels", self.in_channels),
                format!("{} channels", in_c),
            ));
        }

        let expected_len = batch * in_h * in_w * in_c;
        if input.len() != expected_len {
            return Err(CnnError::invalid_shape(
                format!("{} input elements", expected_len),
                format!("{} elements", input.len()),
            ));
        }

        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;

        let mut output_i32 = vec![0i32; batch * out_h * out_w * self.out_channels];

        for b in 0..batch {
            let batch_in_size = in_h * in_w * in_c;
            let batch_out_size = out_h * out_w * self.out_channels;

            let input_slice = &input[b * batch_in_size..(b + 1) * batch_in_size];
            let output_slice = &mut output_i32[b * batch_out_size..(b + 1) * batch_out_size];

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") {
                    unsafe {
                        self.conv_int8_avx2(
                            input_slice,
                            input_zero_point as i32,
                            output_slice,
                            in_h, in_w, out_h, out_w,
                        );
                    }
                } else {
                    self.conv_int8_scalar(
                        input_slice,
                        input_zero_point as i32,
                        output_slice,
                        in_h, in_w, out_h, out_w,
                    );
                }
            }

            #[cfg(not(target_arch = "x86_64"))]
            {
                self.conv_int8_scalar(
                    input_slice,
                    input_zero_point as i32,
                    output_slice,
                    in_h, in_w, out_h, out_w,
                );
            }
        }

        let output_f32 = self.dequantize_output(&output_i32, input_scale);

        Tensor::from_data(output_f32, &[batch, out_h, out_w, self.out_channels])
    }

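    /// Scalar reference kernel for any square kernel size.
    ///
    /// All arithmetic is integer-only: each output accumulates
    /// `acc = bias_q + sum((input - zero_point) * w_q)`. Taps that fall in
    /// the padded border are skipped, which is equivalent to padding with
    /// the zero point (real-valued zeros).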
    fn conv_int8_scalar(
        &self,
        input: &[u8],
        input_zero_point: i32,
        output: &mut [i32],
        in_h: usize,
        in_w: usize,
        out_h: usize,
        out_w: usize,
    ) {
        let ks = self.kernel_size;

        for oh in 0..out_h {
            for ow in 0..out_w {
                for oc in 0..self.out_channels {
                    let mut acc = self.bias_q[oc];

                    for kh in 0..ks {
                        for kw in 0..ks {
                            let ih = (oh * self.stride + kh) as isize - self.padding as isize;
                            let iw = (ow * self.stride + kw) as isize - self.padding as isize;

                            // Out-of-bounds taps contribute nothing: the padded
                            // input is treated as the zero point.
                            if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
                                let ih = ih as usize;
                                let iw = iw as usize;

                                for ic in 0..self.in_channels {
                                    let input_idx = (ih * in_w + iw) * self.in_channels + ic;
                                    // OHWI layout, matching `from_fp32`.
                                    let weight_idx =
                                        ((oc * ks + kh) * ks + kw) * self.in_channels + ic;

                                    acc += (input[input_idx] as i32 - input_zero_point)
                                        * (self.weights_q[weight_idx] as i32);
                                }
                            }
                        }
                    }

                    output[(oh * out_w + ow) * self.out_channels + oc] = acc;
                }
            }
        }
    }

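    /// AVX2 entry point. Currently a stub that delegates to the scalar
    /// kernel; a vectorized body (for example `_mm256_maddubs_epi16` across
    /// the channel dimension) could replace it without changing call sites.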
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn conv_int8_avx2(
        &self,
        input: &[u8],
        input_zero_point: i32,
        output: &mut [i32],
        in_h: usize,
        in_w: usize,
        out_h: usize,
        out_w: usize,
    ) {
        // Placeholder: fall back to the scalar kernel until the intrinsics
        // path is written.
        self.conv_int8_scalar(input, input_zero_point, output, in_h, in_w, out_h, out_w);
    }

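    /// Converts the `i32` accumulators back to fp32. Each accumulator holds
    /// `bias_q + sum((in - zp) * w_q)`, so multiplying by
    /// `input_scale * weight_scales[oc]` recovers the real value.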
    fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
        let mut output = vec![0.0f32; acc.len()];

        for (i, &val) in acc.iter().enumerate() {
            // NHWC output: the channel index is the fastest-varying one.
            let oc = i % self.out_channels;
            let scale = input_scale * self.weight_scales[oc];
            output[i] = val as f32 * scale;
        }

        output
    }

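    /// Number of output channels.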
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }

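    /// Number of input channels.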
    pub fn in_channels(&self) -> usize {
        self.in_channels
    }

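    /// Side length of the square kernel.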
    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

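    /// Stride in both spatial dimensions.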
    pub fn stride(&self) -> usize {
        self.stride
    }

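    /// Zero padding applied on each spatial border.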
    pub fn padding(&self) -> usize {
        self.padding
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Conv2dBuilder;

    #[test]
    fn test_quantized_conv2d_creation() {
        let conv = Conv2dBuilder::new(16, 32, 3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap();

        let qconv = QuantizedConv2d::from_fp32(&conv, 0.01, 128);

        assert_eq!(qconv.in_channels(), 16);
        assert_eq!(qconv.out_channels(), 32);
        assert_eq!(qconv.kernel_size(), 3);
    }

    #[test]
    fn test_quantized_conv2d_forward() {
        let conv = Conv2dBuilder::new(3, 8, 3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap();

        let qconv = QuantizedConv2d::from_fp32(&conv, 0.01, 128);

        // Input held at the zero point everywhere, i.e. all real zeros.
        let input = vec![128u8; 1 * 8 * 8 * 3];
        let input_shape = &[1, 8, 8, 3];

        let output = qconv.forward_int8(&input, input_shape, 0.01, 128).unwrap();

        assert_eq!(output.shape(), &[1, 8, 8, 8]);
    }
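
    // Two additional sketches of properties the quantization scheme should
    // satisfy. They rely only on APIs already used in this file
    // (`Conv2dBuilder`, `from_fp32`, and private-field access from this
    // child module); the tolerances are assumptions, not measured bounds.

    #[test]
    fn test_weight_quantization_round_trip() {
        // Round-to-nearest symmetric quantization keeps every dequantized
        // weight within half a quantization step of the original.
        let conv = Conv2dBuilder::new(4, 4, 3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap();

        let qconv = QuantizedConv2d::from_fp32(&conv, 0.01, 128);
        let weights = conv.weights();
        let block = 3 * 3 * 4; // ks * ks * in_c elements per output channel (OHWI)

        for idx in 0..weights.len() {
            let w = weights[idx];
            let oc = idx / block;
            let scale = qconv.weight_scales[oc];
            let w_hat = qconv.weights_q[idx] as f32 * scale;
            assert!((w_hat - w).abs() <= scale * 0.5 + 1e-6);
        }
    }

    #[test]
    fn test_dequantize_scales_per_channel() {
        // dequantize_output should multiply each accumulator by
        // input_scale * weight_scales[oc], with channels innermost (NHWC).
        let conv = Conv2dBuilder::new(3, 2, 3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap();

        let qconv = QuantizedConv2d::from_fp32(&conv, 0.5, 128);
        let acc = vec![10i32, 20, 30, 40]; // two spatial positions x two channels
        let out = qconv.dequantize_output(&acc, 0.5);

        for (i, &v) in out.iter().enumerate() {
            let oc = i % 2;
            let expected = acc[i] as f32 * 0.5 * qconv.weight_scales[oc];
            assert!((v - expected).abs() < 1e-6);
        }
    }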
}