//! 2D max-pooling layer (rustyml/neural_network/layer/max_pooling_2d.rs).

use crate::neural_network::{ModelError, Tensor};
use crate::traits::Layer;
use ndarray::ArrayD;
use rayon::prelude::*;
5
/// Defines a structure for max pooling operation, used to perform max pooling on 2D data.
///
/// Max pooling is a common downsampling technique in CNNs that reduces
/// the spatial dimensions of feature maps by selecting the maximum value in each pooling window,
/// thereby reducing computation and controlling overfitting.
///
/// # Fields
///
/// - `pool_size` - Size of the pooling window, expressed as (height, width).
/// - `strides` - Stride of the pooling operation, expressed as (vertical stride, horizontal stride).
/// - `input_shape` - Shape of the input tensor.
/// - `input_cache` - Cached input data, used for backpropagation.
/// - `max_positions` - Cache of maximum value positions, used for backpropagation.
///
/// # Example
/// ```rust
/// use rustyml::prelude::*;
/// use ndarray::Array4;
///
/// // Create a simple 4D input tensor: [batch_size, channels, height, width]
/// // Batch size=2, 3 input channels, 6x6 pixels
/// let mut input_data = Array4::zeros((2, 3, 6, 6));
///
/// // Set some specific values so we can predict the max pooling result
/// for b in 0..2 {
///     for c in 0..3 {
///         for i in 0..6 {
///             for j in 0..6 {
///                 // Create input data with an easily observable pattern
///                 input_data[[b, c, i, j]] = (i * j) as f32 + b as f32 * 0.1 + c as f32 * 0.01;
///             }
///         }
///     }
/// }
///
/// let x = input_data.clone().into_dyn();
///
/// // Test using MaxPooling2D in a model
/// let mut model = Sequential::new();
/// model
///     .add(MaxPooling2D::new(
///         (2, 2),           // Pool window size
///         vec![2, 3, 6, 6], // Input shape
///         None,             // Use default stride (2,2)
///     ))
///     .compile(RMSprop::new(0.001, 0.9, 1e-8), MeanSquaredError::new());
///
/// // Create target tensor - corresponding to the pooled shape
/// let y = Array4::ones((2, 3, 3, 3)).into_dyn();
///
/// // Print model structure
/// model.summary();
///
/// // Train the model (run a few epochs)
/// model.fit(&x, &y, 3).unwrap();
///
/// // Use predict for forward propagation prediction
/// let prediction = model.predict(&x);
/// println!("MaxPooling2D prediction results: {:?}", prediction);
///
/// // Check if output shape is correct
/// assert_eq!(prediction.shape(), &[2, 3, 3, 3]);
/// ```
pub struct MaxPooling2D {
    // Pooling window size as (height, width).
    pool_size: (usize, usize),
    // Pooling stride as (vertical, horizontal); defaults to `pool_size` in `new`.
    strides: (usize, usize),
    // Expected input shape [batch_size, channels, height, width]; used by `output_shape`.
    input_shape: Vec<usize>,
    // Input tensor saved by `forward`, consumed by `backward`. `None` before the first forward pass.
    input_cache: Option<Tensor>,
    // One (batch, channel, row, col) entry per output cell, recording where each
    // window's maximum came from; filled by `forward`, read by `backward`.
    max_positions: Option<Vec<(usize, usize, usize, usize)>>,
}
76
77impl MaxPooling2D {
78    /// Creates a new 2D max pooling layer.
79    ///
80    /// # Parameters
81    ///
82    /// - `pool_size` - Size of the pooling window, expressed as (height, width).
83    /// - `input_shape` - Shape of the input tensor, in format \[batch_size, channels, height, width\].
84    /// - `strides` - Stride of the pooling operation, expressed as (vertical stride, horizontal stride). If None, uses the same value as pool_size.
85    ///
86    /// # Returns
87    ///
88    /// * `Self` - A new instance of the MaxPooling2D layer.
89    pub fn new(
90        pool_size: (usize, usize),
91        input_shape: Vec<usize>,
92        strides: Option<(usize, usize)>,
93    ) -> Self {
94        // If stride is not specified, use the same stride as pool size
95        let strides = strides.unwrap_or(pool_size);
96
97        MaxPooling2D {
98            pool_size,
99            strides,
100            input_shape,
101            input_cache: None,
102            max_positions: None,
103        }
104    }
105
106    /// Calculates the output shape of the max pooling layer.
107    ///
108    /// # Parameters
109    ///
110    /// * `input_shape` - Shape of the input tensor, in format \[batch_size, channels, height, width\].
111    ///
112    /// # Returns
113    ///
114    /// A vector containing the calculated output shape, in format \[batch_size, channels, output_height, output_width\].
115    fn calculate_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
116        let batch_size = input_shape[0];
117        let channels = input_shape[1];
118        let input_height = input_shape[2];
119        let input_width = input_shape[3];
120
121        // Calculate the height and width of the output
122        let output_height = (input_height - self.pool_size.0) / self.strides.0 + 1;
123        let output_width = (input_width - self.pool_size.1) / self.strides.1 + 1;
124
125        vec![batch_size, channels, output_height, output_width]
126    }
127
128    /// Performs max pooling operation.
129    ///
130    /// # Parameters
131    ///
132    /// * `input` - Input tensor with shape \[batch_size, channels, height, width\].
133    ///
134    /// # Returns
135    ///
136    /// * `(Tensor, Vec<(usize, usize, usize, usize)>)` - Result of the pooling operation and positions of maximum values.
137    fn max_pool(&self, input: &Tensor) -> (Tensor, Vec<(usize, usize, usize, usize)>) {
138        let input_shape = input.shape();
139        let batch_size = input_shape[0];
140        let channels = input_shape[1];
141        let output_shape = self.calculate_output_shape(input_shape);
142
143        // Pre-allocate output array
144        let mut output = ArrayD::zeros(output_shape.clone());
145        // Vector to store positions of maximum values
146        let mut max_positions = Vec::new();
147
148        // Process each batch and channel in parallel
149        let results: Vec<_> = (0..batch_size)
150            .into_par_iter()
151            .flat_map(|b| {
152                // Clone output_shape here to avoid ownership movement issues
153                let output_shape_clone = output_shape.clone();
154                (0..channels).into_par_iter().map(move |c| {
155                    let mut batch_channel_output = Vec::new();
156                    let mut batch_channel_positions = Vec::new();
157
158                    // Perform pooling for each output position
159                    for i in 0..output_shape_clone[2] {
160                        let i_start = i * self.strides.0;
161
162                        for j in 0..output_shape_clone[3] {
163                            let j_start = j * self.strides.1;
164
165                            // Find maximum value in pooling window
166                            let mut max_val = f32::NEG_INFINITY;
167                            let mut max_pos = (0, 0);
168
169                            for di in 0..self.pool_size.0 {
170                                let i_pos = i_start + di;
171                                if i_pos >= input_shape[2] {
172                                    continue;
173                                }
174
175                                for dj in 0..self.pool_size.1 {
176                                    let j_pos = j_start + dj;
177                                    if j_pos >= input_shape[3] {
178                                        continue;
179                                    }
180
181                                    let val = input[[b, c, i_pos, j_pos]];
182                                    if val > max_val {
183                                        max_val = val;
184                                        max_pos = (i_pos, j_pos);
185                                    }
186                                }
187                            }
188
189                            batch_channel_output.push((i, j, max_val));
190                            batch_channel_positions.push((b, c, max_pos.0, max_pos.1));
191                        }
192                    }
193
194                    ((b, c), (batch_channel_output, batch_channel_positions))
195                })
196            })
197            .collect();
198
199        // Merge results into output tensor
200        for ((b, c), (outputs, positions)) in results {
201            for (i, j, val) in outputs {
202                output[[b, c, i, j]] = val;
203            }
204            max_positions.extend(positions);
205        }
206
207        (output, max_positions)
208    }
209}
210
211impl Layer for MaxPooling2D {
212    fn forward(&mut self, input: &Tensor) -> Tensor {
213        // Save input for backpropagation
214        self.input_cache = Some(input.clone());
215
216        // Perform max pooling operation
217        let (output, max_positions) = self.max_pool(input);
218
219        // Store maximum value positions for backpropagation
220        self.max_positions = Some(max_positions);
221
222        output
223    }
224
225    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
226        if let (Some(input), Some(max_positions)) = (&self.input_cache, &self.max_positions) {
227            let grad_shape = grad_output.shape();
228
229            // Initialize input gradients with same shape as input
230            let mut input_gradients = ArrayD::zeros(input.dim());
231
232            // Create a vector containing update positions and values
233            let gradient_updates: Vec<_> = max_positions
234                .par_iter()
235                .filter_map(|&(b, c, i, j)| {
236                    // Calculate corresponding output gradient index
237                    let out_i = i / self.strides.0;
238                    let out_j = j / self.strides.1;
239
240                    // Ensure indices are within valid range
241                    if out_i < grad_shape[2] && out_j < grad_shape[3] {
242                        // Return index and gradient value
243                        Some(((b, c, i, j), grad_output[[b, c, out_i, out_j]]))
244                    } else {
245                        None
246                    }
247                })
248                .collect();
249
250            // Apply gradient updates sequentially
251            for ((b, c, i, j), grad_val) in gradient_updates {
252                input_gradients[[b, c, i, j]] = grad_val;
253            }
254
255            Ok(input_gradients)
256        } else {
257            Err(ModelError::ProcessingError(
258                "Forward pass has not been run".to_string(),
259            ))
260        }
261    }
262
263    fn layer_type(&self) -> &str {
264        "MaxPooling2D"
265    }
266
267    fn output_shape(&self) -> String {
268        let output_shape = self.calculate_output_shape(&self.input_shape);
269        format!(
270            "({}, {}, {}, {})",
271            output_shape[0], output_shape[1], output_shape[2], output_shape[3]
272        )
273    }
274
275    fn param_count(&self) -> usize {
276        // Pooling layer has no trainable parameters
277        0
278    }
279
280    // Pooling layer has no trainable parameters, so these methods do nothing
281    fn update_parameters_sgd(&mut self, _lr: f32) {
282        // Max pooling layer has no parameters to update
283    }
284
285    fn update_parameters_adam(
286        &mut self,
287        _lr: f32,
288        _beta1: f32,
289        _beta2: f32,
290        _epsilon: f32,
291        _t: u64,
292    ) {
293        // Max pooling layer has no parameters to update
294    }
295
296    fn update_parameters_rmsprop(&mut self, _lr: f32, _rho: f32, _epsilon: f32) {
297        // Max pooling layer has no parameters to update
298    }
299
300    fn get_weights(&self) -> super::LayerWeight {
301        // Max pooling layer has no weights
302        super::LayerWeight::Empty
303    }
304}