Skip to main content

ruvector_cnn/layers/
quantized_depthwise.rs

//! Quantized Depthwise Convolution Layer
//!
//! INT8 quantized depthwise convolution with:
//! - One kernel per channel (no cross-channel mixing)
//! - NHWC (channels-last) input layout; each channel is convolved independently
//! - Per-channel symmetric weight quantization
7
8use crate::{CnnError, CnnResult, Tensor};
9
10use super::{Conv2d, Layer};
11
/// INT8 depthwise convolution layer.
///
/// Every input channel is filtered by its own square kernel and written to the
/// matching output channel; channels never mix, which keeps the parameter count
/// at `channels * kernel_size * kernel_size`.
#[derive(Clone, Debug)]
pub struct QuantizedDepthwiseConv2d {
    /// Per-channel quantized kernels, flattened as `[channels, kh, kw]`.
    weights_q: Vec<i8>,
    /// Dequantization scale of each channel's kernel (`max|w| / 127`).
    weight_scales: Vec<f32>,
    /// Bias already expressed in the i32 accumulator domain
    /// (`bias / (input_scale * weight_scale)`).
    bias_q: Vec<i32>,
    /// Unquantized bias as supplied at construction time.
    bias_f32: Vec<f32>,
    /// Number of channels (identical for input and output).
    channels: usize,
    /// Side length of the square kernel.
    kernel_size: usize,
    /// Spatial step between neighbouring output positions.
    stride: usize,
    /// Border added on every spatial edge of the input.
    padding: usize,
}
38
39impl QuantizedDepthwiseConv2d {
40    /// Create from FP32 depthwise convolution
41    ///
42    /// # Arguments
43    /// * `channels` - Number of input/output channels
44    /// * `kernel_size` - Kernel size (assumed square)
45    /// * `weights` - FP32 weights [channels, kh, kw]
46    /// * `bias` - Optional FP32 bias [channels]
47    /// * `stride` - Stride
48    /// * `padding` - Padding
49    /// * `input_scale` - Expected input activation scale
50    pub fn from_fp32(
51        channels: usize,
52        kernel_size: usize,
53        weights: &[f32],
54        bias: Option<&[f32]>,
55        stride: usize,
56        padding: usize,
57        input_scale: f32,
58    ) -> Self {
59        // Compute per-channel weight scales
60        let mut weight_scales = vec![0.0f32; channels];
61
62        for c in 0..channels {
63            let mut max_abs = 0.0f32;
64            for kh in 0..kernel_size {
65                for kw in 0..kernel_size {
66                    let idx = c * kernel_size * kernel_size + kh * kernel_size + kw;
67                    max_abs = max_abs.max(weights[idx].abs());
68                }
69            }
70            weight_scales[c] = if max_abs > 0.0 {
71                max_abs / 127.0
72            } else {
73                1.0
74            };
75        }
76
77        // Quantize weights
78        let mut weights_q = vec![0i8; weights.len()];
79        for c in 0..channels {
80            let scale = weight_scales[c];
81            for kh in 0..kernel_size {
82                for kw in 0..kernel_size {
83                    let idx = c * kernel_size * kernel_size + kh * kernel_size + kw;
84                    let w_q = (weights[idx] / scale).round().clamp(-127.0, 127.0) as i8;
85                    weights_q[idx] = w_q;
86                }
87            }
88        }
89
90        // Pre-compute bias in i32 accumulator space
91        let bias_f32 = bias.map(|b| b.to_vec()).unwrap_or_else(|| vec![0.0; channels]);
92        let mut bias_q = vec![0i32; channels];
93
94        for c in 0..channels {
95            let combined_scale = input_scale * weight_scales[c];
96            bias_q[c] = if combined_scale > 0.0 {
97                (bias_f32[c] / combined_scale).round() as i32
98            } else {
99                0
100            };
101        }
102
103        Self {
104            weights_q,
105            weight_scales,
106            bias_q,
107            bias_f32,
108            channels,
109            kernel_size,
110            stride,
111            padding,
112        }
113    }
114
115    /// Forward pass with INT8 computation
116    ///
117    /// # Arguments
118    /// * `input` - Quantized u8 input tensor (NHWC layout)
119    /// * `input_shape` - Input shape [N, H, W, C]
120    /// * `input_scale` - Input quantization scale
121    /// * `input_zero_point` - Input quantization zero point
122    pub fn forward_int8(
123        &self,
124        input: &[u8],
125        input_shape: &[usize],
126        input_scale: f32,
127        input_zero_point: u8,
128    ) -> CnnResult<Tensor> {
129        if input_shape.len() != 4 {
130            return Err(CnnError::invalid_shape(
131                "4D input (NHWC)",
132                format!("{}D", input_shape.len())
133            ));
134        }
135
136        let batch = input_shape[0];
137        let in_h = input_shape[1];
138        let in_w = input_shape[2];
139        let in_c = input_shape[3];
140
141        if in_c != self.channels {
142            return Err(CnnError::invalid_shape(
143                format!("{} channels", self.channels),
144                format!("{} channels", in_c)
145            ));
146        }
147
148        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
149        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
150
151        let mut output_i32 = vec![0i32; batch * out_h * out_w * self.channels];
152
153        // Process each batch
154        for b in 0..batch {
155            let batch_in_size = in_h * in_w * in_c;
156            let batch_out_size = out_h * out_w * self.channels;
157
158            let input_slice = &input[b * batch_in_size..(b + 1) * batch_in_size];
159            let output_slice = &mut output_i32[b * batch_out_size..(b + 1) * batch_out_size];
160
161            self.depthwise_conv_int8_scalar(
162                input_slice,
163                input_zero_point as i32,
164                output_slice,
165                in_h, in_w, out_h, out_w,
166            );
167        }
168
169        // Dequantize to f32
170        let output_f32 = self.dequantize_output(&output_i32, input_scale);
171
172        Tensor::from_data(
173            output_f32,
174            &[batch, out_h, out_w, self.channels],
175        )
176    }
177
178    /// Scalar depthwise convolution implementation
179    fn depthwise_conv_int8_scalar(
180        &self,
181        input: &[u8],
182        input_zero_point: i32,
183        output: &mut [i32],
184        in_h: usize,
185        in_w: usize,
186        out_h: usize,
187        out_w: usize,
188    ) {
189        let ks = self.kernel_size;
190
191        // Pre-compute weight sums per channel
192        let mut weight_sums = vec![0i32; self.channels];
193        for c in 0..self.channels {
194            let mut sum = 0i32;
195            for kh in 0..ks {
196                for kw in 0..ks {
197                    let idx = c * ks * ks + kh * ks + kw;
198                    sum += self.weights_q[idx] as i32;
199                }
200            }
201            weight_sums[c] = sum;
202        }
203
204        // Depthwise convolution: each channel processed independently
205        for oh in 0..out_h {
206            for ow in 0..out_w {
207                for c in 0..self.channels {
208                    // Initialize with bias and zero-point correction
209                    let mut acc = self.bias_q[c] - input_zero_point * weight_sums[c];
210
211                    // Convolve over kernel
212                    for kh in 0..ks {
213                        for kw in 0..ks {
214                            let ih = (oh * self.stride + kh) as isize - self.padding as isize;
215                            let iw = (ow * self.stride + kw) as isize - self.padding as isize;
216
217                            if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
218                                let ih = ih as usize;
219                                let iw = iw as usize;
220
221                                let input_idx = (ih * in_w + iw) * self.channels + c;
222                                let weight_idx = c * ks * ks + kh * ks + kw;
223
224                                acc += (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32);
225                            }
226                        }
227                    }
228
229                    output[(oh * out_w + ow) * self.channels + c] = acc;
230                }
231            }
232        }
233    }
234
235    /// Dequantize i32 accumulator to f32
236    fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
237        let mut output = vec![0.0f32; acc.len()];
238
239        for (i, &val) in acc.iter().enumerate() {
240            let c = i % self.channels;
241            let scale = input_scale * self.weight_scales[c];
242            output[i] = val as f32 * scale;
243        }
244
245        output
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3x3 layer with every weight 0.1, no bias, stride 1, padding 1.
    fn make_layer(channels: usize) -> QuantizedDepthwiseConv2d {
        let kernel_size = 3;
        let weights = vec![0.1f32; channels * kernel_size * kernel_size];
        QuantizedDepthwiseConv2d::from_fp32(channels, kernel_size, &weights, None, 1, 1, 0.01)
    }

    #[test]
    fn test_quantized_depthwise_conv2d_creation() {
        let channels = 32;
        let kernel_size = 3;
        let weights = vec![0.1f32; channels * kernel_size * kernel_size];
        let bias_vec = vec![0.0f32; channels];

        let qconv = QuantizedDepthwiseConv2d::from_fp32(
            channels,
            kernel_size,
            &weights,
            Some(&bias_vec),
            1,
            1,
            0.01,
        );

        assert_eq!(qconv.channels, 32);
        assert_eq!(qconv.kernel_size, 3);
    }

    #[test]
    fn test_quantized_depthwise_conv2d_forward() {
        let channels = 16;
        let qconv = make_layer(channels);

        let input = vec![128u8; 1 * 8 * 8 * channels];
        let input_shape = &[1, 8, 8, channels];

        let output = qconv.forward_int8(&input, input_shape, 0.01, 128).unwrap();

        assert_eq!(output.shape(), &[1, 8, 8, channels]);
    }

    #[test]
    fn test_weights_quantized_per_channel() {
        // Channel 0 is all 0.1 (max maps to 127); channel 1 is all zeros (scale 1.0).
        let weights = [vec![0.1f32; 9], vec![0.0f32; 9]].concat();
        let qconv = QuantizedDepthwiseConv2d::from_fp32(2, 3, &weights, None, 1, 1, 0.01);

        assert!(qconv.weights_q[..9].iter().all(|&w| w == 127));
        assert!(qconv.weights_q[9..].iter().all(|&w| w == 0));
        assert_eq!(qconv.weight_scales[1], 1.0);
        assert_eq!(qconv.bias_q, vec![0, 0]);
    }

    #[test]
    fn test_bias_quantized_to_accumulator_space() {
        // weight scale = 1.27 / 127 = 0.01; combined scale = 0.5 * 0.01 = 0.005.
        let weights = vec![1.27f32; 2 * 9];
        let bias = [0.5f32, -0.25];
        let qconv = QuantizedDepthwiseConv2d::from_fp32(2, 3, &weights, Some(&bias), 1, 1, 0.5);

        assert_eq!(qconv.bias_q, vec![100, -50]);
        assert_eq!(qconv.bias_f32, vec![0.5, -0.25]);
    }

    #[test]
    fn test_forward_rejects_non_4d_input() {
        let qconv = make_layer(4);
        let input = vec![128u8; 8 * 8 * 4];

        assert!(qconv.forward_int8(&input, &[8, 8, 4], 0.01, 128).is_err());
    }

    #[test]
    fn test_forward_rejects_channel_mismatch() {
        let qconv = make_layer(4);
        let input = vec![128u8; 8 * 8 * 8];

        assert!(qconv.forward_int8(&input, &[1, 8, 8, 8], 0.01, 128).is_err());
    }
}
297}