1use crate::{simd, CnnError, CnnResult, Tensor};
9
10use super::Layer;
11
/// Alias so the global pool can be referred to with the `*2d` suffix used by
/// the other pooling layers in this module.
pub type GlobalAvgPool2d = GlobalAvgPool;
14
/// Global average pooling: collapses each feature map to a single value by
/// averaging over the full spatial extent (NHWC `[n, h, w, c]` -> `[n, 1, 1, c]`).
/// Stateless, so the struct carries no fields.
#[derive(Debug, Clone, Default)]
pub struct GlobalAvgPool;
24
25impl GlobalAvgPool {
26 pub fn new() -> Self {
28 Self
29 }
30}
31
32impl Layer for GlobalAvgPool {
33 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
34 let shape = input.shape();
35 if shape.len() != 4 {
36 return Err(CnnError::invalid_shape(
37 "4D tensor (NHWC)",
38 format!("{}D tensor", shape.len()),
39 ));
40 }
41
42 let batch = shape[0];
43 let h = shape[1];
44 let w = shape[2];
45 let c = shape[3];
46
47 let out_shape = vec![batch, 1, 1, c];
48 let mut output = Tensor::zeros(&out_shape);
49
50 let batch_in_size = h * w * c;
51
52 for b in 0..batch {
53 let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
54 let output_slice = &mut output.data_mut()[b * c..(b + 1) * c];
55
56 simd::global_avg_pool_simd(input_slice, output_slice, h, w, c);
57 }
58
59 Ok(output)
60 }
61
62 fn name(&self) -> &'static str {
63 "GlobalAvgPool"
64 }
65}
66
/// 2D max pooling over NHWC tensors.
#[derive(Debug, Clone)]
pub struct MaxPool2d {
    // Side length of the square pooling window, in elements.
    kernel_size: usize,
    // Step between successive window origins, in elements.
    stride: usize,
    // Symmetric padding added to each spatial border before pooling.
    // NOTE(review): the fill semantics (zero vs -inf) live in the scalar
    // kernel, which is not visible here — confirm against `simd::scalar`.
    padding: usize,
}
80
81impl MaxPool2d {
82 pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
84 Self {
85 kernel_size,
86 stride,
87 padding,
88 }
89 }
90
91 pub fn with_kernel(kernel_size: usize) -> Self {
93 Self::new(kernel_size, kernel_size, 0)
94 }
95
96 pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
98 if input_shape.len() != 4 {
99 return Err(CnnError::invalid_shape(
100 "4D tensor (NHWC)",
101 format!("{}D tensor", input_shape.len()),
102 ));
103 }
104
105 let batch = input_shape[0];
106 let in_h = input_shape[1];
107 let in_w = input_shape[2];
108 let c = input_shape[3];
109
110 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
111 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113 Ok(vec![batch, out_h, out_w, c])
114 }
115}
116
117impl Layer for MaxPool2d {
118 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
119 let shape = input.shape();
120 if shape.len() != 4 {
121 return Err(CnnError::invalid_shape(
122 "4D tensor (NHWC)",
123 format!("{}D tensor", shape.len()),
124 ));
125 }
126
127 let batch = shape[0];
128 let h = shape[1];
129 let w = shape[2];
130 let c = shape[3];
131
132 let out_shape = self.output_shape(shape)?;
133 let out_h = out_shape[1];
134 let out_w = out_shape[2];
135
136 let mut output = Tensor::zeros(&out_shape);
137
138 let batch_in_size = h * w * c;
139 let batch_out_size = out_h * out_w * c;
140
141 for b in 0..batch {
142 let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
143 let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
144
145 if self.kernel_size == 2 && self.padding == 0 {
146 simd::max_pool_2x2_simd(input_slice, output_slice, h, w, c, self.stride);
147 } else {
148 simd::scalar::max_pool_scalar(
149 input_slice,
150 output_slice,
151 h,
152 w,
153 c,
154 self.kernel_size,
155 self.stride,
156 self.padding,
157 );
158 }
159 }
160
161 Ok(output)
162 }
163
164 fn name(&self) -> &'static str {
165 "MaxPool2d"
166 }
167}
168
/// 2D average pooling over NHWC tensors.
#[derive(Debug, Clone)]
pub struct AvgPool2d {
    // Side length of the square pooling window, in elements.
    kernel_size: usize,
    // Step between successive window origins, in elements.
    stride: usize,
    // Symmetric padding added to each spatial border before pooling.
    // NOTE(review): whether padded elements count toward the divisor is
    // decided by the scalar kernel, not visible here — confirm there.
    padding: usize,
}
182
183impl AvgPool2d {
184 pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
186 Self {
187 kernel_size,
188 stride,
189 padding,
190 }
191 }
192
193 pub fn with_kernel(kernel_size: usize) -> Self {
195 Self::new(kernel_size, kernel_size, 0)
196 }
197
198 pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
200 if input_shape.len() != 4 {
201 return Err(CnnError::invalid_shape(
202 "4D tensor (NHWC)",
203 format!("{}D tensor", input_shape.len()),
204 ));
205 }
206
207 let batch = input_shape[0];
208 let in_h = input_shape[1];
209 let in_w = input_shape[2];
210 let c = input_shape[3];
211
212 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
213 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
214
215 Ok(vec![batch, out_h, out_w, c])
216 }
217}
218
219impl Layer for AvgPool2d {
220 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
221 let shape = input.shape();
222 if shape.len() != 4 {
223 return Err(CnnError::invalid_shape(
224 "4D tensor (NHWC)",
225 format!("{}D tensor", shape.len()),
226 ));
227 }
228
229 let batch = shape[0];
230 let h = shape[1];
231 let w = shape[2];
232 let c = shape[3];
233
234 let out_shape = self.output_shape(shape)?;
235 let out_h = out_shape[1];
236 let out_w = out_shape[2];
237
238 let mut output = Tensor::zeros(&out_shape);
239
240 let batch_in_size = h * w * c;
241 let batch_out_size = out_h * out_w * c;
242
243 for b in 0..batch {
244 let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
245 let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
246
247 if self.kernel_size == 2 && self.padding == 0 {
248 simd::scalar::avg_pool_2x2_scalar(input_slice, output_slice, h, w, c, self.stride);
249 } else {
250 simd::scalar::avg_pool_scalar(
251 input_slice,
252 output_slice,
253 h,
254 w,
255 c,
256 self.kernel_size,
257 self.stride,
258 self.padding,
259 );
260 }
261 }
262
263 Ok(output)
264 }
265
266 fn name(&self) -> &'static str {
267 "AvgPool2d"
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// All-ones input: the spatial mean of ones is one in every channel.
    #[test]
    fn test_global_avg_pool() {
        let layer = GlobalAvgPool::new();
        let out = layer.forward(&Tensor::ones(&[2, 4, 4, 8])).unwrap();

        assert_eq!(out.shape(), &[2, 1, 1, 8]);
        assert!(out.data().iter().all(|&v| (v - 1.0).abs() < 0.001));
    }

    /// Channel 0 holds 1.0 at every pixel and channel 1 holds 2.0, so the
    /// per-channel means are exactly those constants.
    #[test]
    fn test_global_avg_pool_values() {
        let layer = GlobalAvgPool::new();

        let mut values = vec![0.0; 2 * 2 * 2];
        for px in 0..4 {
            values[px * 2] = 1.0;
            values[px * 2 + 1] = 2.0;
        }
        let input = Tensor::from_data(values, &[1, 2, 2, 2]).unwrap();

        let out = layer.forward(&input).unwrap();

        assert!((out.data()[0] - 1.0).abs() < 0.001);
        assert!((out.data()[1] - 2.0).abs() < 0.001);
    }

    /// 2x2/stride-2 pooling halves each spatial dimension.
    #[test]
    fn test_max_pool2d() {
        let layer = MaxPool2d::new(2, 2, 0);
        let out = layer.forward(&Tensor::ones(&[1, 8, 8, 4])).unwrap();

        assert_eq!(out.shape(), &[1, 4, 4, 4]);
    }

    /// A single 2x2 window: the maximum of 1..=4 is 4.
    #[test]
    fn test_max_pool2d_values() {
        let layer = MaxPool2d::new(2, 2, 0);
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]).unwrap();

        let out = layer.forward(&input).unwrap();

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        assert_eq!(out.data()[0], 4.0);
    }

    #[test]
    fn test_max_pool2d_output_shape() {
        let layer = MaxPool2d::new(2, 2, 0);
        let shape = layer.output_shape(&[1, 224, 224, 64]).unwrap();
        assert_eq!(shape, vec![1, 112, 112, 64]);
    }

    /// 2x2/stride-2 pooling halves each spatial dimension.
    #[test]
    fn test_avg_pool2d() {
        let layer = AvgPool2d::new(2, 2, 0);
        let out = layer.forward(&Tensor::ones(&[1, 8, 8, 4])).unwrap();

        assert_eq!(out.shape(), &[1, 4, 4, 4]);
    }

    /// A single 2x2 window: the mean of 1..=4 is 2.5.
    #[test]
    fn test_avg_pool2d_values() {
        let layer = AvgPool2d::new(2, 2, 0);
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]).unwrap();

        let out = layer.forward(&input).unwrap();

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        assert!((out.data()[0] - 2.5).abs() < 0.001);
    }

    /// Overlapping windows: a 2x2 kernel with stride 1 shrinks 4x4 to 3x3.
    #[test]
    fn test_max_pool_with_stride1() {
        let layer = MaxPool2d::new(2, 1, 0);
        let shape = layer.output_shape(&[1, 4, 4, 1]).unwrap();
        assert_eq!(shape, vec![1, 3, 3, 1]);
    }
}