1use crate::{CnnError, CnnResult, Tensor};
8
9#[derive(Debug, Clone)]
14pub struct QuantizedMaxPool2d {
15 kernel_size: usize,
16 stride: usize,
17 padding: usize,
18}
19
20impl QuantizedMaxPool2d {
21 pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
23 Self {
24 kernel_size,
25 stride,
26 padding,
27 }
28 }
29
30 pub fn forward_int8(
38 &self,
39 input: &[u8],
40 input_shape: &[usize],
41 scale: f32,
42 zero_point: u8,
43 ) -> CnnResult<(Vec<u8>, Vec<usize>, f32, u8)> {
44 if input_shape.len() != 4 {
45 return Err(CnnError::invalid_shape(
46 "4D input (NHWC)",
47 format!("{}D", input_shape.len())
48 ));
49 }
50
51 let batch = input_shape[0];
52 let in_h = input_shape[1];
53 let in_w = input_shape[2];
54 let channels = input_shape[3];
55
56 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
57 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
58
59 let mut output = vec![zero_point; batch * out_h * out_w * channels];
60
61 for b in 0..batch {
62 for oh in 0..out_h {
63 for ow in 0..out_w {
64 for c in 0..channels {
65 let mut max_val = zero_point;
66
67 for kh in 0..self.kernel_size {
68 for kw in 0..self.kernel_size {
69 let ih = (oh * self.stride + kh) as isize - self.padding as isize;
70 let iw = (ow * self.stride + kw) as isize - self.padding as isize;
71
72 if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
73 let ih = ih as usize;
74 let iw = iw as usize;
75 let input_idx = ((b * in_h + ih) * in_w + iw) * channels + c;
76 max_val = max_val.max(input[input_idx]);
77 }
78 }
79 }
80
81 let output_idx = ((b * out_h + oh) * out_w + ow) * channels + c;
82 output[output_idx] = max_val;
83 }
84 }
85 }
86 }
87
88 Ok((output, vec![batch, out_h, out_w, channels], scale, zero_point))
89 }
90}
91
92#[derive(Debug, Clone)]
97pub struct QuantizedAvgPool2d {
98 kernel_size: usize,
99 stride: usize,
100 padding: usize,
101}
102
103impl QuantizedAvgPool2d {
104 pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
106 Self {
107 kernel_size,
108 stride,
109 padding,
110 }
111 }
112
113 pub fn forward_int8(
124 &self,
125 input: &[u8],
126 input_shape: &[usize],
127 input_scale: f32,
128 input_zero_point: u8,
129 ) -> CnnResult<(Vec<u8>, Vec<usize>, f32, u8)> {
130 if input_shape.len() != 4 {
131 return Err(CnnError::invalid_shape(
132 "4D input (NHWC)",
133 format!("{}D", input_shape.len())
134 ));
135 }
136
137 let batch = input_shape[0];
138 let in_h = input_shape[1];
139 let in_w = input_shape[2];
140 let channels = input_shape[3];
141
142 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
143 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
144
145 let mut output_i16 = vec![0i16; batch * out_h * out_w * channels];
147
148 let kernel_area = self.kernel_size * self.kernel_size;
149
150 for b in 0..batch {
151 for oh in 0..out_h {
152 for ow in 0..out_w {
153 for c in 0..channels {
154 let mut sum = 0i16;
155 let mut count = 0;
156
157 for kh in 0..self.kernel_size {
158 for kw in 0..self.kernel_size {
159 let ih = (oh * self.stride + kh) as isize - self.padding as isize;
160 let iw = (ow * self.stride + kw) as isize - self.padding as isize;
161
162 if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
163 let ih = ih as usize;
164 let iw = iw as usize;
165 let input_idx = ((b * in_h + ih) * in_w + iw) * channels + c;
166 sum += input[input_idx] as i16;
167 count += 1;
168 }
169 }
170 }
171
172 let avg = if count > 0 {
174 (sum + count / 2) / count } else {
176 input_zero_point as i16
177 };
178
179 let output_idx = ((b * out_h + oh) * out_w + ow) * channels + c;
180 output_i16[output_idx] = avg;
181 }
182 }
183 }
184 }
185
186 let output: Vec<u8> = output_i16.iter()
188 .map(|&v| v.clamp(0, 255) as u8)
189 .collect();
190
191 Ok((output, vec![batch, out_h, out_w, channels], input_scale, input_zero_point))
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 window with stride 2 over a 4x4 single-channel input: every output
    /// element must be the maximum of its own quadrant.
    #[test]
    fn test_quantized_maxpool2d() {
        let pool = QuantizedMaxPool2d::new(2, 2, 0);

        let input: Vec<u8> = vec![
            100, 150, 200, 255,
            120, 180, 210, 230,
            110, 140, 190, 240,
            130, 160, 220, 250,
        ];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 0).unwrap();

        assert_eq!(output_shape, vec![1, 2, 2, 1]);
        // Quantization parameters are passed through untouched.
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 0);

        assert_eq!(output, vec![180, 255, 160, 250]);
    }

    /// 2x2 window with stride 2 over constant quadrants: every output element
    /// must equal the constant value of its quadrant.
    #[test]
    fn test_quantized_avgpool2d() {
        let pool = QuantizedAvgPool2d::new(2, 2, 0);

        let input: Vec<u8> = vec![
            100, 100, 200, 200,
            100, 100, 200, 200,
            100, 100, 200, 200,
            100, 100, 200, 200,
        ];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 0).unwrap();

        assert_eq!(output_shape, vec![1, 2, 2, 1]);
        // Quantization parameters are passed through untouched.
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 0);

        assert_eq!(output, vec![100, 200, 100, 200]);
    }

    /// 3x3 window, stride 1, padding 1 keeps the spatial size; with a constant
    /// input above the zero point every output equals that constant.
    #[test]
    fn test_quantized_maxpool2d_with_padding() {
        let pool = QuantizedMaxPool2d::new(3, 1, 1);

        let input = vec![100u8; 16];
        let input_shape = &[1, 4, 4, 1];

        let (output, output_shape, scale, zp) =
            pool.forward_int8(&input, input_shape, 0.01, 50).unwrap();

        assert_eq!(output_shape, vec![1, 4, 4, 1]);
        assert_eq!(scale, 0.01);
        assert_eq!(zp, 50);
        assert_eq!(output, vec![100u8; 16]);
    }
}