// rustyml/neural_network/layer/max_pooling_2d.rs
1use crate::neural_network::{ModelError, Tensor};
2use crate::traits::Layer;
3use ndarray::ArrayD;
4use rayon::prelude::*;
5
/// Defines a structure for max pooling operation, used to perform max pooling on 2D data.
///
/// Max pooling is a common downsampling technique in CNNs that reduces
/// the spatial dimensions of feature maps by selecting the maximum value in each pooling window,
/// thereby reducing computation and controlling overfitting.
///
/// # Fields
///
/// - `pool_size` - Size of the pooling window, expressed as (height, width).
/// - `strides` - Stride of the pooling operation, expressed as (vertical stride, horizontal stride).
/// - `input_shape` - Shape of the input tensor.
/// - `input_cache` - Cached input data, used for backpropagation.
/// - `max_positions` - Cache of maximum value positions, used for backpropagation.
///
/// # Example
/// ```rust
/// use rustyml::prelude::*;
/// use ndarray::Array4;
///
/// // Create a simple 4D input tensor: [batch_size, channels, height, width]
/// // Batch size=2, 3 input channels, 6x6 pixels
/// let mut input_data = Array4::zeros((2, 3, 6, 6));
///
/// // Set some specific values so we can predict the max pooling result
/// for b in 0..2 {
///     for c in 0..3 {
///         for i in 0..6 {
///             for j in 0..6 {
///                 // Create input data with an easily observable pattern
///                 input_data[[b, c, i, j]] = (i * j) as f32 + b as f32 * 0.1 + c as f32 * 0.01;
///             }
///         }
///     }
/// }
///
/// let x = input_data.clone().into_dyn();
///
/// // Test using MaxPooling2D in a model
/// let mut model = Sequential::new();
/// model
///     .add(MaxPooling2D::new(
///         (2, 2),           // Pool window size
///         vec![2, 3, 6, 6], // Input shape
///         None,             // Use default stride (2,2)
///     ))
///     .compile(RMSprop::new(0.001, 0.9, 1e-8), MeanSquaredError::new());
///
/// // Create target tensor - corresponding to the pooled shape:
/// // 6x6 input pooled by a 2x2 window with stride 2 gives 3x3 output
/// let y = Array4::ones((2, 3, 3, 3)).into_dyn();
///
/// // Print model structure
/// model.summary();
///
/// // Train the model (run a few epochs)
/// model.fit(&x, &y, 3).unwrap();
///
/// // Use predict for forward propagation prediction
/// let prediction = model.predict(&x);
/// println!("MaxPooling2D prediction results: {:?}", prediction);
///
/// // Check if output shape is correct
/// assert_eq!(prediction.shape(), &[2, 3, 3, 3]);
/// ```
pub struct MaxPooling2D {
    /// Pooling window size as (height, width).
    pool_size: (usize, usize),
    /// Step between pooling windows as (vertical stride, horizontal stride);
    /// defaults to `pool_size` (non-overlapping windows) when not supplied.
    strides: (usize, usize),
    /// Expected input shape: [batch_size, channels, height, width].
    input_shape: Vec<usize>,
    /// Input tensor cached by `forward`, read by `backward`.
    input_cache: Option<Tensor>,
    /// One (batch, channel, row, col) input coordinate per output position —
    /// where each window maximum came from; written by `forward`.
    max_positions: Option<Vec<(usize, usize, usize, usize)>>,
}
76
77impl MaxPooling2D {
78 /// Creates a new 2D max pooling layer.
79 ///
80 /// # Parameters
81 ///
82 /// - `pool_size` - Size of the pooling window, expressed as (height, width).
83 /// - `input_shape` - Shape of the input tensor, in format \[batch_size, channels, height, width\].
84 /// - `strides` - Stride of the pooling operation, expressed as (vertical stride, horizontal stride). If None, uses the same value as pool_size.
85 ///
86 /// # Returns
87 ///
88 /// * `Self` - A new instance of the MaxPooling2D layer.
89 pub fn new(
90 pool_size: (usize, usize),
91 input_shape: Vec<usize>,
92 strides: Option<(usize, usize)>,
93 ) -> Self {
94 // If stride is not specified, use the same stride as pool size
95 let strides = strides.unwrap_or(pool_size);
96
97 MaxPooling2D {
98 pool_size,
99 strides,
100 input_shape,
101 input_cache: None,
102 max_positions: None,
103 }
104 }
105
106 /// Calculates the output shape of the max pooling layer.
107 ///
108 /// # Parameters
109 ///
110 /// * `input_shape` - Shape of the input tensor, in format \[batch_size, channels, height, width\].
111 ///
112 /// # Returns
113 ///
114 /// A vector containing the calculated output shape, in format \[batch_size, channels, output_height, output_width\].
115 fn calculate_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
116 let batch_size = input_shape[0];
117 let channels = input_shape[1];
118 let input_height = input_shape[2];
119 let input_width = input_shape[3];
120
121 // Calculate the height and width of the output
122 let output_height = (input_height - self.pool_size.0) / self.strides.0 + 1;
123 let output_width = (input_width - self.pool_size.1) / self.strides.1 + 1;
124
125 vec![batch_size, channels, output_height, output_width]
126 }
127
128 /// Performs max pooling operation.
129 ///
130 /// # Parameters
131 ///
132 /// * `input` - Input tensor with shape \[batch_size, channels, height, width\].
133 ///
134 /// # Returns
135 ///
136 /// * `(Tensor, Vec<(usize, usize, usize, usize)>)` - Result of the pooling operation and positions of maximum values.
137 fn max_pool(&self, input: &Tensor) -> (Tensor, Vec<(usize, usize, usize, usize)>) {
138 let input_shape = input.shape();
139 let batch_size = input_shape[0];
140 let channels = input_shape[1];
141 let output_shape = self.calculate_output_shape(input_shape);
142
143 // Pre-allocate output array
144 let mut output = ArrayD::zeros(output_shape.clone());
145 // Vector to store positions of maximum values
146 let mut max_positions = Vec::new();
147
148 // Process each batch and channel in parallel
149 let results: Vec<_> = (0..batch_size)
150 .into_par_iter()
151 .flat_map(|b| {
152 // Clone output_shape here to avoid ownership movement issues
153 let output_shape_clone = output_shape.clone();
154 (0..channels).into_par_iter().map(move |c| {
155 let mut batch_channel_output = Vec::new();
156 let mut batch_channel_positions = Vec::new();
157
158 // Perform pooling for each output position
159 for i in 0..output_shape_clone[2] {
160 let i_start = i * self.strides.0;
161
162 for j in 0..output_shape_clone[3] {
163 let j_start = j * self.strides.1;
164
165 // Find maximum value in pooling window
166 let mut max_val = f32::NEG_INFINITY;
167 let mut max_pos = (0, 0);
168
169 for di in 0..self.pool_size.0 {
170 let i_pos = i_start + di;
171 if i_pos >= input_shape[2] {
172 continue;
173 }
174
175 for dj in 0..self.pool_size.1 {
176 let j_pos = j_start + dj;
177 if j_pos >= input_shape[3] {
178 continue;
179 }
180
181 let val = input[[b, c, i_pos, j_pos]];
182 if val > max_val {
183 max_val = val;
184 max_pos = (i_pos, j_pos);
185 }
186 }
187 }
188
189 batch_channel_output.push((i, j, max_val));
190 batch_channel_positions.push((b, c, max_pos.0, max_pos.1));
191 }
192 }
193
194 ((b, c), (batch_channel_output, batch_channel_positions))
195 })
196 })
197 .collect();
198
199 // Merge results into output tensor
200 for ((b, c), (outputs, positions)) in results {
201 for (i, j, val) in outputs {
202 output[[b, c, i, j]] = val;
203 }
204 max_positions.extend(positions);
205 }
206
207 (output, max_positions)
208 }
209}
210
211impl Layer for MaxPooling2D {
212 fn forward(&mut self, input: &Tensor) -> Tensor {
213 // Save input for backpropagation
214 self.input_cache = Some(input.clone());
215
216 // Perform max pooling operation
217 let (output, max_positions) = self.max_pool(input);
218
219 // Store maximum value positions for backpropagation
220 self.max_positions = Some(max_positions);
221
222 output
223 }
224
225 fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
226 if let (Some(input), Some(max_positions)) = (&self.input_cache, &self.max_positions) {
227 let grad_shape = grad_output.shape();
228
229 // Initialize input gradients with same shape as input
230 let mut input_gradients = ArrayD::zeros(input.dim());
231
232 // Create a vector containing update positions and values
233 let gradient_updates: Vec<_> = max_positions
234 .par_iter()
235 .filter_map(|&(b, c, i, j)| {
236 // Calculate corresponding output gradient index
237 let out_i = i / self.strides.0;
238 let out_j = j / self.strides.1;
239
240 // Ensure indices are within valid range
241 if out_i < grad_shape[2] && out_j < grad_shape[3] {
242 // Return index and gradient value
243 Some(((b, c, i, j), grad_output[[b, c, out_i, out_j]]))
244 } else {
245 None
246 }
247 })
248 .collect();
249
250 // Apply gradient updates sequentially
251 for ((b, c, i, j), grad_val) in gradient_updates {
252 input_gradients[[b, c, i, j]] = grad_val;
253 }
254
255 Ok(input_gradients)
256 } else {
257 Err(ModelError::ProcessingError(
258 "Forward pass has not been run".to_string(),
259 ))
260 }
261 }
262
263 fn layer_type(&self) -> &str {
264 "MaxPooling2D"
265 }
266
267 fn output_shape(&self) -> String {
268 let output_shape = self.calculate_output_shape(&self.input_shape);
269 format!(
270 "({}, {}, {}, {})",
271 output_shape[0], output_shape[1], output_shape[2], output_shape[3]
272 )
273 }
274
275 fn param_count(&self) -> usize {
276 // Pooling layer has no trainable parameters
277 0
278 }
279
280 // Pooling layer has no trainable parameters, so these methods do nothing
281 fn update_parameters_sgd(&mut self, _lr: f32) {
282 // Max pooling layer has no parameters to update
283 }
284
285 fn update_parameters_adam(
286 &mut self,
287 _lr: f32,
288 _beta1: f32,
289 _beta2: f32,
290 _epsilon: f32,
291 _t: u64,
292 ) {
293 // Max pooling layer has no parameters to update
294 }
295
296 fn update_parameters_rmsprop(&mut self, _lr: f32, _rho: f32, _epsilon: f32) {
297 // Max pooling layer has no parameters to update
298 }
299
300 fn get_weights(&self) -> super::LayerWeight {
301 // Max pooling layer has no weights
302 super::LayerWeight::Empty
303 }
304}