
Function max_pool2d_backward 

pub fn max_pool2d_backward<F>(
    grad_output: &ArrayView4<'_, F>,
    indices: &ArrayView4<'_, usize>,
    inputshape: (usize, usize, usize, usize),
) -> LinalgResult<Array4<F>>

Performs the backward pass of the 2D max pooling operation.

Takes the gradient of the pooled output and routes each element back to the position in the original input that held the maximum of the corresponding pooling window, as recorded in indices during the forward pass.
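To illustrate the routing step, here is a minimal std-only sketch that does not use ndarray or scirs2. It assumes flat, row-major NCHW buffers and assumes that each entry of indices stores the flat offset h * width + w of the maximum inside its (batch, channel) plane. That encoding, the function name max_pool2d_backward_flat, and the summing of overlapping windows are all illustrative assumptions, not the documented behavior of scirs2_linalg.

```rust
/// Hypothetical std-only sketch of a max-pooling backward pass (not the scirs2 API).
/// Tensors are flat, row-major NCHW buffers. ASSUMPTION: `indices[k]` holds the flat
/// offset `h * width + w` of the window maximum inside its (n, c) input plane.
fn max_pool2d_backward_flat(
    grad_output: &[f32],
    indices: &[usize],
    output_shape: (usize, usize, usize, usize),
    input_shape: (usize, usize, usize, usize),
) -> Result<Vec<f32>, String> {
    let (n_out, c_out, oh, ow) = output_shape;
    let (n, c, h, w) = input_shape;
    if n_out != n || c_out != c {
        return Err("batch/channel dimensions of grad_output and input differ".to_string());
    }
    let out_len = n * c * oh * ow;
    if grad_output.len() != out_len || indices.len() != out_len {
        return Err("grad_output and indices must both match output_shape".to_string());
    }
    let plane_in = h * w;
    let plane_out = oh * ow;
    // Input positions that never held a window maximum keep a zero gradient.
    let mut grad_input = vec![0.0f32; n * c * plane_in];
    for (k, (&g, &idx)) in grad_output.iter().zip(indices).enumerate() {
        if idx >= plane_in {
            return Err(format!("index {idx} out of bounds for a {h}x{w} plane"));
        }
        // Output element k lives in (n, c) plane number k / plane_out.
        let plane = k / plane_out;
        // Accumulate (+=) so a position that is the maximum of several
        // overlapping windows receives the sum of their gradients.
        grad_input[plane * plane_in + idx] += g;
    }
    Ok(grad_input)
}
```

For the 4x4 input filled with 4 * h + w and 2x2 windows with stride 2, the window maxima sit at flat offsets 5, 7, 13 and 15, so calling this sketch with a grad_output of four ones and indices [5, 7, 13, 15] yields a 16-element buffer with 1.0 at those offsets and 0.0 elsewhere.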

§Arguments

  • grad_output - Gradient of the loss with respect to the pooled output, with shape (batch_size, channels, output_height, output_width)
  • indices - Indices of the maximum values recorded by the forward pass (the second value returned by max_pool2d)
  • inputshape - Shape of the original input tensor, (batch_size, channels, height, width)
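The arguments above imply a shape contract between the three inputs. The following sketch spells it out as a hypothetical helper, check_backward_shapes, which is not part of scirs2_linalg. It assumes that indices has exactly the same shape as grad_output (one recorded index per output element) and that pooling never changes the batch or channel dimensions; the spatial sizes are not checked because they also depend on the pool size, stride and padding used in the forward pass.

```rust
/// Shape of a 4-D NCHW tensor: (batch, channels, height, width).
type Shape4 = (usize, usize, usize, usize);

/// Hypothetical helper, not part of scirs2_linalg: checks the shape contract
/// implied by the argument list before a max-pooling backward pass.
fn check_backward_shapes(
    grad_output: Shape4,
    indices: Shape4,
    inputshape: Shape4,
) -> Result<(), String> {
    // Assumption: the forward pass records exactly one index per output element.
    if grad_output != indices {
        return Err(format!(
            "indices shape {indices:?} differs from grad_output shape {grad_output:?}"
        ));
    }
    // Pooling acts on each (batch, channel) plane separately, so these must match.
    if (grad_output.0, grad_output.1) != (inputshape.0, inputshape.1) {
        return Err(format!(
            "grad_output {grad_output:?} and inputshape {inputshape:?} disagree on batch or channels"
        ));
    }
    Ok(())
}
```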

§Returns

  • The gradient with respect to the original input, with shape inputshape, wrapped in a LinalgResult
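For reference, the returned gradient follows the standard max-pooling derivative. In the notation below, introduced here for illustration only, x is the input, y the pooled output, W_{p,q} the pooling window of output position (p, q), and a_{n,c,p,q} the argmax recorded in indices. Summing over windows that share a maximum is the usual convention for overlapping windows; this page does not state how the implementation handles that case.

```latex
\begin{aligned}
y_{n,c,p,q} &= \max_{(i,j) \in W_{p,q}} x_{n,c,i,j},
\qquad
a_{n,c,p,q} = \operatorname*{arg\,max}_{(i,j) \in W_{p,q}} x_{n,c,i,j}, \\[4pt]
\frac{\partial L}{\partial x_{n,c,i,j}}
  &= \sum_{p,q} \mathbb{1}\!\left[(i,j) = a_{n,c,p,q}\right]
     \frac{\partial L}{\partial y_{n,c,p,q}}
\end{aligned}
```

With the 2x2 windows and stride 2 used in the example, the windows are disjoint, so each input position receives at most one term: the four window maxima get the corresponding entry of grad_output and every other position gets 0.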

§Examples

use scirs2_core::ndarray::Array4;
use scirs2_linalg::convolution::{max_pool2d, max_pool2d_backward};

// Create a 1x1x4x4 input tensor
let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
// Fill with the values 0..16 in row-major order (input[h][w] = 4 * h + w)
for h in 0..4 {
    for w in 0..4 {
        input[[0, 0, h, w]] = (h * 4 + w) as f32;
    }
}

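// Apply 2x2 max pooling with stride 2 and no padding (forward pass);
// only the recorded indices are needed for the backward pass
let (_output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();

// Gradient of the loss with respect to the 1x1x2x2 pooled output
let grad_output = Array4::<f32>::ones((1, 1, 2, 2));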

// Compute gradient of the input (backward pass)
let grad_input = max_pool2d_backward(
    &grad_output.view(),
    &indices.view(),
    (1, 1, 4, 4),
).unwrap();

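// The input gradient has the same shape as the original input
assert_eq!(grad_input.shape(), &[1, 1, 4, 4]);

// Each 2x2 window's maximum is its bottom-right element, so the unit
// gradients land at (1, 1), (1, 3), (3, 1) and (3, 3)
assert_eq!(grad_input[[0, 0, 1, 1]], 1.0);
assert_eq!(grad_input[[0, 0, 3, 3]], 1.0);
assert_eq!(grad_input[[0, 0, 0, 0]], 0.0);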