Skip to main content

ruvector_cnn/layers/
quantized_pooling.rs

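//! Quantized pooling layers.
//!
//! Pooling operations on 8-bit affine-quantized (`u8`) tensors in NHWC layout:
//! - `QuantizedMaxPool2d`: works directly on the quantized values; the output reuses the input's scale and zero point.
//! - `QuantizedAvgPool2d`: accumulates window sums in an integer type wider than `u8` before a rounding division.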
6
7use crate::{CnnError, CnnResult, Tensor};
8
9/// Quantized Max Pooling 2D
10///
11/// Operates directly in INT8 domain without scale changes.
12/// Output has the same scale and zero-point as input.
13#[derive(Debug, Clone)]
14pub struct QuantizedMaxPool2d {
15    kernel_size: usize,
16    stride: usize,
17    padding: usize,
18}
19
20impl QuantizedMaxPool2d {
21    /// Create a new quantized max pooling layer
22    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
23        Self {
24            kernel_size,
25            stride,
26            padding,
27        }
28    }
29
30    /// Forward pass with INT8 input
31    ///
32    /// # Arguments
33    /// * `input` - Quantized u8 input tensor (NHWC layout)
34    /// * `input_shape` - Input shape [N, H, W, C]
35    /// * `scale` - Input/output scale (unchanged)
36    /// * `zero_point` - Input/output zero point (unchanged)
37    pub fn forward_int8(
38        &self,
39        input: &[u8],
40        input_shape: &[usize],
41        scale: f32,
42        zero_point: u8,
43    ) -> CnnResult<(Vec<u8>, Vec<usize>, f32, u8)> {
44        if input_shape.len() != 4 {
45            return Err(CnnError::invalid_shape(
46                "4D input (NHWC)",
47                format!("{}D", input_shape.len())
48            ));
49        }
50
51        let batch = input_shape[0];
52        let in_h = input_shape[1];
53        let in_w = input_shape[2];
54        let channels = input_shape[3];
55
56        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
57        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
58
59        let mut output = vec![zero_point; batch * out_h * out_w * channels];
60
61        for b in 0..batch {
62            for oh in 0..out_h {
63                for ow in 0..out_w {
64                    for c in 0..channels {
65                        let mut max_val = zero_point;
66
67                        for kh in 0..self.kernel_size {
68                            for kw in 0..self.kernel_size {
69                                let ih = (oh * self.stride + kh) as isize - self.padding as isize;
70                                let iw = (ow * self.stride + kw) as isize - self.padding as isize;
71
72                                if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
73                                    let ih = ih as usize;
74                                    let iw = iw as usize;
75                                    let input_idx = ((b * in_h + ih) * in_w + iw) * channels + c;
76                                    max_val = max_val.max(input[input_idx]);
77                                }
78                            }
79                        }
80
81                        let output_idx = ((b * out_h + oh) * out_w + ow) * channels + c;
82                        output[output_idx] = max_val;
83                    }
84                }
85            }
86        }
87
88        Ok((output, vec![batch, out_h, out_w, channels], scale, zero_point))
89    }
90}
91
92/// Quantized Average Pooling 2D
93///
94/// Uses i16 intermediate precision to accumulate sums before division.
95/// Output may have different scale than input due to averaging.
96#[derive(Debug, Clone)]
97pub struct QuantizedAvgPool2d {
98    kernel_size: usize,
99    stride: usize,
100    padding: usize,
101}
102
103impl QuantizedAvgPool2d {
104    /// Create a new quantized average pooling layer
105    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
106        Self {
107            kernel_size,
108            stride,
109            padding,
110        }
111    }
112
113    /// Forward pass with INT8 input
114    ///
115    /// # Arguments
116    /// * `input` - Quantized u8 input tensor (NHWC layout)
117    /// * `input_shape` - Input shape [N, H, W, C]
118    /// * `input_scale` - Input scale
119    /// * `input_zero_point` - Input zero point
120    ///
121    /// # Returns
122    /// (output, output_shape, output_scale, output_zero_point)
123    pub fn forward_int8(
124        &self,
125        input: &[u8],
126        input_shape: &[usize],
127        input_scale: f32,
128        input_zero_point: u8,
129    ) -> CnnResult<(Vec<u8>, Vec<usize>, f32, u8)> {
130        if input_shape.len() != 4 {
131            return Err(CnnError::invalid_shape(
132                "4D input (NHWC)",
133                format!("{}D", input_shape.len())
134            ));
135        }
136
137        let batch = input_shape[0];
138        let in_h = input_shape[1];
139        let in_w = input_shape[2];
140        let channels = input_shape[3];
141
142        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
143        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
144
145        // Use i16 for accumulation to avoid overflow
146        let mut output_i16 = vec![0i16; batch * out_h * out_w * channels];
147
148        let kernel_area = self.kernel_size * self.kernel_size;
149
150        for b in 0..batch {
151            for oh in 0..out_h {
152                for ow in 0..out_w {
153                    for c in 0..channels {
154                        let mut sum = 0i16;
155                        let mut count = 0;
156
157                        for kh in 0..self.kernel_size {
158                            for kw in 0..self.kernel_size {
159                                let ih = (oh * self.stride + kh) as isize - self.padding as isize;
160                                let iw = (ow * self.stride + kw) as isize - self.padding as isize;
161
162                                if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
163                                    let ih = ih as usize;
164                                    let iw = iw as usize;
165                                    let input_idx = ((b * in_h + ih) * in_w + iw) * channels + c;
166                                    sum += input[input_idx] as i16;
167                                    count += 1;
168                                }
169                            }
170                        }
171
172                        // Compute average
173                        let avg = if count > 0 {
174                            (sum + count / 2) / count // Rounding division
175                        } else {
176                            input_zero_point as i16
177                        };
178
179                        let output_idx = ((b * out_h + oh) * out_w + ow) * channels + c;
180                        output_i16[output_idx] = avg;
181                    }
182                }
183            }
184        }
185
186        // Convert i16 back to u8
187        let output: Vec<u8> = output_i16.iter()
188            .map(|&v| v.clamp(0, 255) as u8)
189            .collect();
190
191        // Output scale remains the same as input for average pooling
192        Ok((output, vec![batch, out_h, out_w, channels], input_scale, input_zero_point))
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_maxpool2d() {
        let pool = QuantizedMaxPool2d::new(2, 2, 0);

        let input = vec![
            100, 150, 200, 255,
            120, 180, 210, 230,
            110, 140, 190, 240,
            130, 160, 220, 250,
        ];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 0).unwrap();

        assert_eq!(output_shape, vec![1, 2, 2, 1]);
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 0);

        // Each non-overlapping 2x2 window yields its largest element.
        assert_eq!(output, vec![180u8, 255, 160, 250]);
    }

    #[test]
    fn test_quantized_avgpool2d() {
        let pool = QuantizedAvgPool2d::new(2, 2, 0);

        let input = vec![
            100, 100, 200, 200,
            100, 100, 200, 200,
            100, 100, 200, 200,
            100, 100, 200, 200,
        ];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 0).unwrap();

        assert_eq!(output_shape, vec![1, 2, 2, 1]);
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 0);

        // Constant windows average exactly to their value.
        assert_eq!(output, vec![100u8, 200, 100, 200]);
    }

    #[test]
    fn test_quantized_maxpool2d_with_padding() {
        let pool = QuantizedMaxPool2d::new(3, 1, 1);

        let input = vec![100u8; 16];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 50).unwrap();

        assert_eq!(output_shape, vec![1, 4, 4, 1]);
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 50);

        // Padded positions are never read, so a constant input passes through.
        assert_eq!(output, vec![100u8; 16]);
    }
}