// ruvector_cnn/layers/pooling.rs
//! Pooling Layers
//!
//! SIMD-optimized pooling operations:
//! - GlobalAvgPool: Global average pooling over spatial dimensions
//! - MaxPool2d: Max pooling with configurable kernel and stride
//! - AvgPool2d: Average pooling with configurable kernel and stride

8use crate::{simd, CnnError, CnnResult, Tensor};
9
10use super::Layer;
11
/// Alias for [`GlobalAvgPool`] (for API compatibility with the `*2d` naming
/// used by the other pooling layers in this module).
pub type GlobalAvgPool2d = GlobalAvgPool;
14
/// Global Average Pooling
///
/// Reduces spatial dimensions to 1x1 by averaging over all spatial positions.
/// Commonly used before the final fully-connected layer in CNNs.
///
/// Input: [batch, height, width, channels]
/// Output: [batch, 1, 1, channels]
///
/// The layer is a unit struct: it holds no parameters or state, so `Clone`
/// and `Default` are trivially derivable.
#[derive(Debug, Clone, Default)]
pub struct GlobalAvgPool;
24
25impl GlobalAvgPool {
26    /// Create a new global average pooling layer
27    pub fn new() -> Self {
28        Self
29    }
30}
31
32impl Layer for GlobalAvgPool {
33    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
34        let shape = input.shape();
35        if shape.len() != 4 {
36            return Err(CnnError::invalid_shape(
37                "4D tensor (NHWC)",
38                format!("{}D tensor", shape.len()),
39            ));
40        }
41
42        let batch = shape[0];
43        let h = shape[1];
44        let w = shape[2];
45        let c = shape[3];
46
47        let out_shape = vec![batch, 1, 1, c];
48        let mut output = Tensor::zeros(&out_shape);
49
50        let batch_in_size = h * w * c;
51
52        for b in 0..batch {
53            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
54            let output_slice = &mut output.data_mut()[b * c..(b + 1) * c];
55
56            simd::global_avg_pool_simd(input_slice, output_slice, h, w, c);
57        }
58
59        Ok(output)
60    }
61
62    fn name(&self) -> &'static str {
63        "GlobalAvgPool"
64    }
65}
66
/// 2D Max Pooling
///
/// Performs max pooling over spatial dimensions with configurable kernel size,
/// stride, and padding.
///
/// The kernel is square and the same stride/padding is applied on both spatial
/// axes. Data layout is NHWC.
#[derive(Debug, Clone)]
pub struct MaxPool2d {
    /// Kernel size (height and width)
    kernel_size: usize,
    /// Stride
    stride: usize,
    /// Padding
    padding: usize,
}
80
81impl MaxPool2d {
82    /// Create a new max pooling layer
83    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
84        Self {
85            kernel_size,
86            stride,
87            padding,
88        }
89    }
90
91    /// Create a max pooling layer with stride equal to kernel size
92    pub fn with_kernel(kernel_size: usize) -> Self {
93        Self::new(kernel_size, kernel_size, 0)
94    }
95
96    /// Get the output shape for a given input shape
97    pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
98        if input_shape.len() != 4 {
99            return Err(CnnError::invalid_shape(
100                "4D tensor (NHWC)",
101                format!("{}D tensor", input_shape.len()),
102            ));
103        }
104
105        let batch = input_shape[0];
106        let in_h = input_shape[1];
107        let in_w = input_shape[2];
108        let c = input_shape[3];
109
110        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
111        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113        Ok(vec![batch, out_h, out_w, c])
114    }
115}
116
117impl Layer for MaxPool2d {
118    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
119        let shape = input.shape();
120        if shape.len() != 4 {
121            return Err(CnnError::invalid_shape(
122                "4D tensor (NHWC)",
123                format!("{}D tensor", shape.len()),
124            ));
125        }
126
127        let batch = shape[0];
128        let h = shape[1];
129        let w = shape[2];
130        let c = shape[3];
131
132        let out_shape = self.output_shape(shape)?;
133        let out_h = out_shape[1];
134        let out_w = out_shape[2];
135
136        let mut output = Tensor::zeros(&out_shape);
137
138        let batch_in_size = h * w * c;
139        let batch_out_size = out_h * out_w * c;
140
141        for b in 0..batch {
142            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
143            let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
144
145            if self.kernel_size == 2 && self.padding == 0 {
146                simd::max_pool_2x2_simd(input_slice, output_slice, h, w, c, self.stride);
147            } else {
148                simd::scalar::max_pool_scalar(
149                    input_slice,
150                    output_slice,
151                    h,
152                    w,
153                    c,
154                    self.kernel_size,
155                    self.stride,
156                    self.padding,
157                );
158            }
159        }
160
161        Ok(output)
162    }
163
164    fn name(&self) -> &'static str {
165        "MaxPool2d"
166    }
167}
168
/// 2D Average Pooling
///
/// Performs average pooling over spatial dimensions with configurable kernel size,
/// stride, and padding.
///
/// The kernel is square and the same stride/padding is applied on both spatial
/// axes. Data layout is NHWC.
#[derive(Debug, Clone)]
pub struct AvgPool2d {
    /// Kernel size (height and width)
    kernel_size: usize,
    /// Stride
    stride: usize,
    /// Padding
    padding: usize,
}
182
183impl AvgPool2d {
184    /// Create a new average pooling layer
185    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
186        Self {
187            kernel_size,
188            stride,
189            padding,
190        }
191    }
192
193    /// Create an average pooling layer with stride equal to kernel size
194    pub fn with_kernel(kernel_size: usize) -> Self {
195        Self::new(kernel_size, kernel_size, 0)
196    }
197
198    /// Get the output shape for a given input shape
199    pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
200        if input_shape.len() != 4 {
201            return Err(CnnError::invalid_shape(
202                "4D tensor (NHWC)",
203                format!("{}D tensor", input_shape.len()),
204            ));
205        }
206
207        let batch = input_shape[0];
208        let in_h = input_shape[1];
209        let in_w = input_shape[2];
210        let c = input_shape[3];
211
212        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
213        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
214
215        Ok(vec![batch, out_h, out_w, c])
216    }
217}
218
219impl Layer for AvgPool2d {
220    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
221        let shape = input.shape();
222        if shape.len() != 4 {
223            return Err(CnnError::invalid_shape(
224                "4D tensor (NHWC)",
225                format!("{}D tensor", shape.len()),
226            ));
227        }
228
229        let batch = shape[0];
230        let h = shape[1];
231        let w = shape[2];
232        let c = shape[3];
233
234        let out_shape = self.output_shape(shape)?;
235        let out_h = out_shape[1];
236        let out_w = out_shape[2];
237
238        let mut output = Tensor::zeros(&out_shape);
239
240        let batch_in_size = h * w * c;
241        let batch_out_size = out_h * out_w * c;
242
243        for b in 0..batch {
244            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
245            let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
246
247            if self.kernel_size == 2 && self.padding == 0 {
248                simd::scalar::avg_pool_2x2_scalar(input_slice, output_slice, h, w, c, self.stride);
249            } else {
250                simd::scalar::avg_pool_scalar(
251                    input_slice,
252                    output_slice,
253                    h,
254                    w,
255                    c,
256                    self.kernel_size,
257                    self.stride,
258                    self.padding,
259                );
260            }
261        }
262
263        Ok(output)
264    }
265
266    fn name(&self) -> &'static str {
267        "AvgPool2d"
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_global_avg_pool() {
        let layer = GlobalAvgPool::new();
        let out = layer.forward(&Tensor::ones(&[2, 4, 4, 8])).unwrap();

        assert_eq!(out.shape(), &[2, 1, 1, 8]);

        // Averaging an all-ones tensor must yield ones.
        assert!(out.data().iter().all(|&v| (v - 1.0).abs() < 0.001));
    }

    #[test]
    fn test_global_avg_pool_values() {
        // Channel 0 holds 1.0 at every pixel, channel 1 holds 2.0.
        let mut values = vec![0.0; 2 * 2 * 2];
        for px in 0..4 {
            values[px * 2] = 1.0;
            values[px * 2 + 1] = 2.0;
        }
        let input = Tensor::from_data(values, &[1, 2, 2, 2]).unwrap();

        let out = GlobalAvgPool::new().forward(&input).unwrap();

        // The per-channel averages should be exactly the constant values.
        assert!((out.data()[0] - 1.0).abs() < 0.001);
        assert!((out.data()[1] - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_max_pool2d() {
        let out = MaxPool2d::new(2, 2, 0)
            .forward(&Tensor::ones(&[1, 8, 8, 4]))
            .unwrap();

        assert_eq!(out.shape(), &[1, 4, 4, 4]);
    }

    #[test]
    fn test_max_pool2d_values() {
        // 2x2 single-channel input [[1, 2], [3, 4]] pools down to its maximum.
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]).unwrap();

        let out = MaxPool2d::new(2, 2, 0).forward(&input).unwrap();

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        assert_eq!(out.data()[0], 4.0);
    }

    #[test]
    fn test_max_pool2d_output_shape() {
        let dims = MaxPool2d::new(2, 2, 0)
            .output_shape(&[1, 224, 224, 64])
            .unwrap();

        assert_eq!(dims, vec![1, 112, 112, 64]);
    }

    #[test]
    fn test_avg_pool2d() {
        let out = AvgPool2d::new(2, 2, 0)
            .forward(&Tensor::ones(&[1, 8, 8, 4]))
            .unwrap();

        assert_eq!(out.shape(), &[1, 4, 4, 4]);
    }

    #[test]
    fn test_avg_pool2d_values() {
        // (1 + 2 + 3 + 4) / 4 = 2.5
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]).unwrap();

        let out = AvgPool2d::new(2, 2, 0).forward(&input).unwrap();

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        assert!((out.data()[0] - 2.5).abs() < 0.001);
    }

    #[test]
    fn test_max_pool_with_stride1() {
        // Overlapping windows: a 2x2 kernel at stride 1 over 4x4 gives 3x3.
        let dims = MaxPool2d::new(2, 1, 0).output_shape(&[1, 4, 4, 1]).unwrap();

        assert_eq!(dims, vec![1, 3, 3, 1]);
    }
}