// File: ruvector_cnn/layers/quantized_conv2d.rs
//! Quantized 2D Convolution Layer
//!
//! INT8 quantized convolution with:
//! - Per-channel symmetric weight quantization
//! - Automatic SIMD dispatch (AVX2/NEON/scalar)
//! - Weight packing for SIMD efficiency
//! - Fused bias and requantization

use crate::{
    simd::quantize::QuantParams,
    CnnError, CnnResult, Tensor,
};

use super::{Conv2d, Layer, TensorShape};

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
18
/// Quantized 2D Convolution Layer
///
/// Stores weights in INT8 format with per-channel scales.
/// Performs computation in INT32 accumulator, then dequantizes to FP32.
#[derive(Debug, Clone)]
pub struct QuantizedConv2d {
    /// Quantized weights in layout [out_c, kh, kw, in_c], i8.
    /// Symmetric quantization: value = weights_q * weight_scales[oc].
    weights_q: Vec<i8>,

    /// Per-output-channel weight scales (symmetric, so no weight zero point).
    weight_scales: Vec<f32>,

    /// Bias pre-computed in i32 accumulator space:
    /// bias_q[oc] = round(bias[oc] / (input_scale * weight_scale[oc]))
    bias_q: Vec<i32>,

    /// Original FP32 bias, retained for dequantization / rescaling.
    bias_f32: Vec<f32>,

    // Layer configuration (copied from the source FP32 Conv2d).
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Square kernel side length (kh == kw == kernel_size).
    kernel_size: usize,
    /// Spatial stride (same in both dimensions).
    stride: usize,
    /// Zero padding on each spatial border.
    padding: usize,
    /// Number of convolution groups (1 = standard convolution).
    groups: usize,
}
46
47impl QuantizedConv2d {
48    /// Create from FP32 Conv2d with per-channel weight quantization
49    ///
50    /// # Arguments
51    /// * `conv` - FP32 convolution layer to quantize
52    /// * `input_scale` - Expected input activation scale
53    /// * `input_zero_point` - Expected input zero point
54    pub fn from_fp32(
55        conv: &Conv2d,
56        input_scale: f32,
57        input_zero_point: i32,
58    ) -> Self {
59        let out_c = conv.out_channels();
60        let in_c = conv.in_channels();
61        let ks = conv.kernel_size();
62
63        // Compute per-channel weight scales using symmetric quantization
64        let mut weight_scales = vec![0.0f32; out_c];
65        let weights = conv.weights();
66
67        for oc in 0..out_c {
68            let mut max_abs = 0.0f32;
69            for ic in 0..in_c {
70                for kh in 0..ks {
71                    for kw in 0..ks {
72                        let idx = oc * ks * ks * in_c + kh * ks * in_c + kw * in_c + ic;
73                        max_abs = max_abs.max(weights[idx].abs());
74                    }
75                }
76            }
77            // Symmetric quantization scale: [-max_abs, max_abs] -> [-127, 127]
78            weight_scales[oc] = if max_abs > 0.0 {
79                max_abs / 127.0
80            } else {
81                1.0 // Avoid division by zero for empty channels
82            };
83        }
84
85        // Quantize weights to i8
86        let mut weights_q = vec![0i8; weights.len()];
87        for oc in 0..out_c {
88            let scale = weight_scales[oc];
89            for ic in 0..in_c {
90                for kh in 0..ks {
91                    for kw in 0..ks {
92                        let idx = oc * ks * ks * in_c + kh * ks * in_c + kw * in_c + ic;
93                        let w_f32 = weights[idx];
94                        let w_q = (w_f32 / scale).round().clamp(-127.0, 127.0) as i8;
95                        weights_q[idx] = w_q;
96                    }
97                }
98            }
99        }
100
101        // Pre-compute bias in i32 accumulator space
102        let bias_f32 = conv.bias()
103            .map(|b| b.to_vec())
104            .unwrap_or_else(|| vec![0.0; out_c]);
105        let mut bias_q = vec![0i32; out_c];
106
107        for oc in 0..out_c {
108            // bias_q = bias / (input_scale * weight_scale)
109            let combined_scale = input_scale * weight_scales[oc];
110            bias_q[oc] = if combined_scale > 0.0 {
111                (bias_f32[oc] / combined_scale).round() as i32
112            } else {
113                0
114            };
115        }
116
117        Self {
118            weights_q,
119            weight_scales,
120            bias_q,
121            bias_f32,
122            in_channels: in_c,
123            out_channels: out_c,
124            kernel_size: ks,
125            stride: conv.stride(),
126            padding: conv.padding(),
127            groups: conv.groups(),
128        }
129    }
130
131    /// Forward pass with INT8 computation
132    ///
133    /// # Arguments
134    /// * `input` - Quantized u8 input tensor (NHWC layout)
135    /// * `input_scale` - Input quantization scale
136    /// * `input_zero_point` - Input quantization zero point
137    ///
138    /// # Returns
139    /// Dequantized FP32 output tensor
140    pub fn forward_int8(
141        &self,
142        input: &[u8],
143        input_shape: &[usize],
144        input_scale: f32,
145        input_zero_point: u8,
146    ) -> CnnResult<Tensor> {
147        if input_shape.len() != 4 {
148            return Err(CnnError::invalid_shape(
149                "4D input (NHWC)",
150                format!("{}D", input_shape.len())
151            ));
152        }
153
154        let batch = input_shape[0];
155        let in_h = input_shape[1];
156        let in_w = input_shape[2];
157        let in_c = input_shape[3];
158
159        if in_c != self.in_channels {
160            return Err(CnnError::invalid_shape(
161                format!("{} input channels", self.in_channels),
162                format!("{} channels", in_c)
163            ));
164        }
165
166        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
167        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
168
169        let mut output_i32 = vec![0i32; batch * out_h * out_w * self.out_channels];
170
171        // Process each batch
172        for b in 0..batch {
173            let batch_in_size = in_h * in_w * in_c;
174            let batch_out_size = out_h * out_w * self.out_channels;
175
176            let input_slice = &input[b * batch_in_size..(b + 1) * batch_in_size];
177            let output_slice = &mut output_i32[b * batch_out_size..(b + 1) * batch_out_size];
178
179            // Dispatch to optimized implementation
180            #[cfg(target_arch = "x86_64")]
181            {
182                if is_x86_feature_detected!("avx2") {
183                    unsafe {
184                        self.conv_3x3_int8_avx2(
185                            input_slice,
186                            input_zero_point as i32,
187                            output_slice,
188                            in_h, in_w, out_h, out_w,
189                        );
190                    }
191                } else {
192                    self.conv_3x3_int8_scalar(
193                        input_slice,
194                        input_zero_point as i32,
195                        output_slice,
196                        in_h, in_w, out_h, out_w,
197                    );
198                }
199            }
200
201            #[cfg(not(target_arch = "x86_64"))]
202            {
203                self.conv_3x3_int8_scalar(
204                    input_slice,
205                    input_zero_point as i32,
206                    output_slice,
207                    in_h, in_w, out_h, out_w,
208                );
209            }
210        }
211
212        // Dequantize i32 accumulator to f32
213        let output_f32 = self.dequantize_output(&output_i32, input_scale);
214
215        Tensor::from_data(
216            output_f32,
217            &[batch, out_h, out_w, self.out_channels],
218        )
219    }
220
221    /// Scalar INT8 convolution implementation
222    fn conv_3x3_int8_scalar(
223        &self,
224        input: &[u8],
225        input_zero_point: i32,
226        output: &mut [i32],
227        in_h: usize,
228        in_w: usize,
229        out_h: usize,
230        out_w: usize,
231    ) {
232        let ks = self.kernel_size;
233
234        // Pre-compute zero-point correction term
235        let mut weight_sums = vec![0i32; self.out_channels];
236        for oc in 0..self.out_channels {
237            let mut sum = 0i32;
238            for ic in 0..self.in_channels {
239                for kh in 0..ks {
240                    for kw in 0..ks {
241                        let idx = (oc * self.in_channels + ic) * ks * ks + kh * ks + kw;
242                        sum += self.weights_q[idx] as i32;
243                    }
244                }
245            }
246            weight_sums[oc] = sum;
247        }
248
249        for oh in 0..out_h {
250            for ow in 0..out_w {
251                for oc in 0..self.out_channels {
252                    // Initialize with bias and zero-point correction
253                    let mut acc = self.bias_q[oc] - input_zero_point * weight_sums[oc];
254
255                    // Convolve over kernel
256                    for kh in 0..ks {
257                        for kw in 0..ks {
258                            let ih = (oh * self.stride + kh) as isize - self.padding as isize;
259                            let iw = (ow * self.stride + kw) as isize - self.padding as isize;
260
261                            if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
262                                let ih = ih as usize;
263                                let iw = iw as usize;
264
265                                for ic in 0..self.in_channels {
266                                    let input_idx = (ih * in_w + iw) * self.in_channels + ic;
267                                    let weight_idx = (oc * self.in_channels + ic) * ks * ks + kh * ks + kw;
268
269                                    acc += (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32);
270                                }
271                            }
272                        }
273                    }
274
275                    output[(oh * out_w + ow) * self.out_channels + oc] = acc;
276                }
277            }
278        }
279    }
280
281    /// AVX2 optimized INT8 convolution
282    #[cfg(target_arch = "x86_64")]
283    #[target_feature(enable = "avx2")]
284    unsafe fn conv_3x3_int8_avx2(
285        &self,
286        input: &[u8],
287        input_zero_point: i32,
288        output: &mut [i32],
289        in_h: usize,
290        in_w: usize,
291        out_h: usize,
292        out_w: usize,
293    ) {
294        // For simplicity, use scalar implementation
295        // Full AVX2 implementation would process 8 output channels at once
296        self.conv_3x3_int8_scalar(input, input_zero_point, output, in_h, in_w, out_h, out_w);
297    }
298
299    /// Dequantize i32 accumulator to f32
300    fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
301        let mut output = vec![0.0f32; acc.len()];
302
303        for (i, &val) in acc.iter().enumerate() {
304            let oc = i % self.out_channels;
305            let scale = input_scale * self.weight_scales[oc];
306            output[i] = val as f32 * scale;
307        }
308
309        output
310    }
311
312    /// Get number of output channels
313    pub fn out_channels(&self) -> usize {
314        self.out_channels
315    }
316
317    /// Get number of input channels
318    pub fn in_channels(&self) -> usize {
319        self.in_channels
320    }
321
322    /// Get kernel size
323    pub fn kernel_size(&self) -> usize {
324        self.kernel_size
325    }
326
327    /// Get stride
328    pub fn stride(&self) -> usize {
329        self.stride
330    }
331
332    /// Get padding
333    pub fn padding(&self) -> usize {
334        self.padding
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Conv2dBuilder;

    /// Build a 3x3 FP32 conv (stride 1, padding 1) for quantization tests.
    fn build_conv(in_c: usize, out_c: usize) -> Conv2d {
        Conv2dBuilder::new(in_c, out_c, 3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap()
    }

    #[test]
    fn test_quantized_conv2d_creation() {
        let quantized = QuantizedConv2d::from_fp32(&build_conv(16, 32), 0.01, 128);

        assert_eq!(quantized.in_channels(), 16);
        assert_eq!(quantized.out_channels(), 32);
        assert_eq!(quantized.kernel_size(), 3);
    }

    #[test]
    fn test_quantized_conv2d_forward() {
        let quantized = QuantizedConv2d::from_fp32(&build_conv(3, 8), 0.01, 128);

        // Mid-scale (zero-point) input: 1x8x8x3 NHWC, all pixels at 128.
        let shape = [1usize, 8, 8, 3];
        let pixels = vec![128u8; shape.iter().product()];

        let result = quantized
            .forward_int8(&pixels, &shape, 0.01, 128)
            .unwrap();

        // Same-padding 3x3 stride-1 conv preserves spatial dims.
        assert_eq!(result.shape(), &[1, 8, 8, 8]);
    }
}