// rustorch_core/ops/pool.rs — 2D max-pooling forward and backward ops.

1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use rayon::prelude::*;
5use std::sync::Arc;
6
7// --- MaxPool2d ---
8// Input: (N, C, H, W)
9// Output: (N, C, H_out, W_out)
10// H_out = (H + 2*pad - kernel_size) / stride + 1
11
/// Saved context for the MaxPool2d backward pass.
///
/// Holds the forward-pass input (needed to re-locate the argmax of each
/// pooling window during backprop) together with the pooling parameters.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    // Forward-pass input, laid out (N, C, H, W).
    pub input: Tensor,
    // Pooling window size as (height, width).
    pub kernel_size: (usize, usize),
    // Window step as (height, width).
    pub stride: (usize, usize),
    // Implicit padding as (height, width); padded positions never win the max
    // in the forward pass, so they receive no gradient.
    pub padding: (usize, usize),
}
19
20impl BackwardOp for MaxPool2dBackward {
21    fn backward(&self, grad: &Tensor) {
22        if self.input.requires_grad() {
23            let (k_h, k_w) = self.kernel_size;
24            let (stride_h, stride_w) = self.stride;
25            let (pad_h, pad_w) = self.padding;
26
27            let input_shape = self.input.shape();
28            let grad_shape = grad.shape();
29
30            let n = input_shape[0];
31            let c = input_shape[1];
32            let h_in = input_shape[2];
33            let w_in = input_shape[3];
34
35            let h_out = grad_shape[2];
36            let w_out = grad_shape[3];
37
38            let input_guard = self.input.data();
39            let grad_guard = grad.data();
40            let input_data = &*input_guard;
41            let grad_data = &*grad_guard;
42
43            // We need to scatter gradients back to the max indices.
44            // Since multiple output pixels might map to same input pixel (with overlap), we accumulate.
45            // However, maxpool usually takes the gradient from the max position.
46
47            // Since we can't easily do scatter_add in parallel without atomics or locking,
48            // let's iterate over output and add to a local buffer, then reduce?
49            // Or use sequential update for simplicity first, or parallel over N, C.
50
51            let mut grad_input_data = vec![0.0; n * c * h_in * w_in];
52
53            // For MaxPool backward, we need to find which index was the max.
54            // We re-compute the forward pass window to find the index.
55
56            // Parallelize over N, C
57            // Note: Parallel writing to grad_input_data is unsafe if windows overlap.
58            // But MaxPool windows usually stride >= kernel_size for non-overlapping.
59            // If they overlap, we need atomic adds.
60            // For now, let's assume standard non-overlapping or handle overlap sequentially within a thread?
61            // Actually, if stride < kernel_size, multiple output pixels depend on same input.
62            // So we can't parallelize purely by output pixel without atomic add to input.
63            // But we CAN parallelize by N and C, as they are independent.
64
65            // We can use chunks_mut to split grad_input_data by N*C
66            let chunk_size = h_in * w_in;
67            grad_input_data
68                .par_chunks_mut(chunk_size)
69                .enumerate()
70                .for_each(|(i, grad_in_chunk)| {
71                    let b = i / c;
72                    let ci = i % c;
73
74                    // Corresponding section in input and grad
75                    let input_offset = (b * c + ci) * h_in * w_in;
76                    let grad_offset = (b * c + ci) * h_out * w_out;
77
78                    for ho in 0..h_out {
79                        for wo in 0..w_out {
80                            let h_start = (ho * stride_h).saturating_sub(pad_h);
81                            let w_start = (wo * stride_w).saturating_sub(pad_w);
82                            let h_end = (h_start + k_h).min(h_in);
83                            let w_end = (w_start + k_w).min(w_in);
84
85                            // Find max in window
86                            let mut max_val = -f32::INFINITY;
87                            let mut max_idx = (h_start, w_start); // Default to start
88
89                            for h in h_start..h_end {
90                                for w in w_start..w_end {
91                                    let val = input_data[input_offset + h * w_in + w];
92                                    if val > max_val {
93                                        max_val = val;
94                                        max_idx = (h, w);
95                                    }
96                                }
97                            }
98
99                            // Add gradient to max index
100                            // Safety: max_idx is within h_in, w_in bounds
101                            let g_val = grad_data[grad_offset + ho * w_out + wo];
102                            grad_in_chunk[max_idx.0 * w_in + max_idx.1] += g_val;
103                        }
104                    }
105                });
106
107            let grad_input_tensor =
108                Tensor::new_with_storage(Storage::new(grad_input_data), self.input.shape());
109            self.input.accumulate_grad(&grad_input_tensor);
110            self.input.backward_step();
111        }
112    }
113}
114
115pub fn max_pool2d(
116    input: &Tensor,
117    kernel_size: (usize, usize),
118    stride: (usize, usize),
119    padding: (usize, usize),
120) -> Tensor {
121    let shape = input.shape();
122    if shape.len() != 4 {
123        panic!("MaxPool2d requires 4D tensor (N, C, H, W)");
124    }
125
126    let n = shape[0];
127    let c = shape[1];
128    let h_in = shape[2];
129    let w_in = shape[3];
130
131    let (k_h, k_w) = kernel_size;
132    let (stride_h, stride_w) = stride;
133    let (pad_h, pad_w) = padding;
134
135    let h_out = (h_in + 2 * pad_h - k_h) / stride_h + 1;
136    let w_out = (w_in + 2 * pad_w - k_w) / stride_w + 1;
137
138    let input_guard = input.data();
139    let input_data = &*input_guard;
140
141    let total_elements = n * c * h_out * w_out;
142    let result_data: Vec<f32> = (0..total_elements)
143        .into_par_iter()
144        .map(|idx| {
145            let wo = idx % w_out;
146            let ho = (idx / w_out) % h_out;
147            let ci = (idx / (w_out * h_out)) % c;
148            let b = idx / (w_out * h_out * c);
149
150            let h_start_raw = (ho * stride_h) as isize - pad_h as isize;
151            let w_start_raw = (wo * stride_w) as isize - pad_w as isize;
152
153            let mut max_val = -f32::INFINITY;
154
155            for kh in 0..k_h {
156                for kw in 0..k_w {
157                    let h_in_idx = h_start_raw + kh as isize;
158                    let w_in_idx = w_start_raw + kw as isize;
159
160                    if h_in_idx >= 0
161                        && h_in_idx < h_in as isize
162                        && w_in_idx >= 0
163                        && w_in_idx < w_in as isize
164                    {
165                        let val = input_data
166                            [((b * c + ci) * h_in + h_in_idx as usize) * w_in + w_in_idx as usize];
167                        if val > max_val {
168                            max_val = val;
169                        }
170                    }
171                }
172            }
173            max_val
174        })
175        .collect();
176
177    let storage = Storage::new(result_data);
178    let mut tensor = Tensor::new_with_storage(storage, &[n, c, h_out, w_out]);
179
180    if input.requires_grad() {
181        tensor.set_requires_grad_mut(true);
182        tensor.set_op(Arc::new(MaxPool2dBackward {
183            input: input.clone(),
184            kernel_size,
185            stride,
186            padding,
187        }));
188    }
189
190    tensor
191}