pub fn max_pool2d_backward<F>(
grad_output: &ArrayView4<'_, F>,
indices: &ArrayView4<'_, usize>,
inputshape: (usize, usize, usize, usize),
) -> LinalgResult<Array4<F>>
Performs the backward pass of the 2D max pooling operation.
Takes the gradients of the pooled outputs and distributes them back to the locations of the maximum values in the original input.
# Arguments
- `grad_output` - Gradient of the output tensor, of shape (batch_size, channels, output_height, output_width)
- `indices` - Indices of the maximum values, as returned by the forward pass (`max_pool2d`)
- `inputshape` - Shape of the original input tensor (batch_size, channels, height, width)
# Returns
- The gradient with respect to the input, as an array of shape `inputshape`
# Examples
use scirs2_core::ndarray::Array4;
use scirs2_linalg::convolution::{max_pool2d, max_pool2d_backward};
// Create a 1x1x4x4 input tensor
let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
// Fill with sample data
for h in 0..4 {
for w in 0..4 {
input[[0, 0, h, w]] = (h * 4 + w) as f32;
}
}
// Apply max pooling (forward pass)
let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();
// Create gradient of the output
let mut grad_output = Array4::<f32>::ones((1, 1, 2, 2));
// Compute gradient of the input (backward pass)
let grad_input = max_pool2d_backward(
&grad_output.view(),
&indices.view(),
(1, 1, 4, 4),
).unwrap();
// Verify shape
assert_eq!(grad_input.shape(), &[1, 1, 4, 4]);