// rustorch_core/ops/pool.rs
use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use rayon::prelude::*;
use std::sync::Arc;
/// Backward op for 2-D max pooling (NCHW): routes each output-position
/// gradient back to the single input position that held the window maximum
/// in the forward pass. The argmax is recomputed from the saved input
/// rather than stored as indices.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    // Forward input, kept alive so backward can recompute each window's argmax.
    pub input: Tensor,
    // Pooling window size as (k_h, k_w).
    pub kernel_size: (usize, usize),
    // Window step as (stride_h, stride_w).
    pub stride: (usize, usize),
    // Implicit padding as (pad_h, pad_w); padded positions are never read,
    // they only shift/clip the window (matches the forward's bounds checks).
    pub padding: (usize, usize),
}
19
20impl BackwardOp for MaxPool2dBackward {
21 fn backward(&self, grad: &Tensor) {
22 if self.input.requires_grad() {
23 let (k_h, k_w) = self.kernel_size;
24 let (stride_h, stride_w) = self.stride;
25 let (pad_h, pad_w) = self.padding;
26
27 let input_shape = self.input.shape();
28 let grad_shape = grad.shape();
29
30 let n = input_shape[0];
31 let c = input_shape[1];
32 let h_in = input_shape[2];
33 let w_in = input_shape[3];
34
35 let h_out = grad_shape[2];
36 let w_out = grad_shape[3];
37
38 let input_guard = self.input.data();
39 let grad_guard = grad.data();
40 let input_data = &*input_guard;
41 let grad_data = &*grad_guard;
42
43 let mut grad_input_data = vec![0.0; n * c * h_in * w_in];
52
53 let chunk_size = h_in * w_in;
67 grad_input_data
68 .par_chunks_mut(chunk_size)
69 .enumerate()
70 .for_each(|(i, grad_in_chunk)| {
71 let b = i / c;
72 let ci = i % c;
73
74 let input_offset = (b * c + ci) * h_in * w_in;
76 let grad_offset = (b * c + ci) * h_out * w_out;
77
78 for ho in 0..h_out {
79 for wo in 0..w_out {
80 let h_start = (ho * stride_h).saturating_sub(pad_h);
81 let w_start = (wo * stride_w).saturating_sub(pad_w);
82 let h_end = (h_start + k_h).min(h_in);
83 let w_end = (w_start + k_w).min(w_in);
84
85 let mut max_val = -f32::INFINITY;
87 let mut max_idx = (h_start, w_start); for h in h_start..h_end {
90 for w in w_start..w_end {
91 let val = input_data[input_offset + h * w_in + w];
92 if val > max_val {
93 max_val = val;
94 max_idx = (h, w);
95 }
96 }
97 }
98
99 let g_val = grad_data[grad_offset + ho * w_out + wo];
102 grad_in_chunk[max_idx.0 * w_in + max_idx.1] += g_val;
103 }
104 }
105 });
106
107 let grad_input_tensor =
108 Tensor::new_with_storage(Storage::new(grad_input_data), self.input.shape());
109 self.input.accumulate_grad(&grad_input_tensor);
110 self.input.backward_step();
111 }
112 }
113}
114
115pub fn max_pool2d(
116 input: &Tensor,
117 kernel_size: (usize, usize),
118 stride: (usize, usize),
119 padding: (usize, usize),
120) -> Tensor {
121 let shape = input.shape();
122 if shape.len() != 4 {
123 panic!("MaxPool2d requires 4D tensor (N, C, H, W)");
124 }
125
126 let n = shape[0];
127 let c = shape[1];
128 let h_in = shape[2];
129 let w_in = shape[3];
130
131 let (k_h, k_w) = kernel_size;
132 let (stride_h, stride_w) = stride;
133 let (pad_h, pad_w) = padding;
134
135 let h_out = (h_in + 2 * pad_h - k_h) / stride_h + 1;
136 let w_out = (w_in + 2 * pad_w - k_w) / stride_w + 1;
137
138 let input_guard = input.data();
139 let input_data = &*input_guard;
140
141 let total_elements = n * c * h_out * w_out;
142 let result_data: Vec<f32> = (0..total_elements)
143 .into_par_iter()
144 .map(|idx| {
145 let wo = idx % w_out;
146 let ho = (idx / w_out) % h_out;
147 let ci = (idx / (w_out * h_out)) % c;
148 let b = idx / (w_out * h_out * c);
149
150 let h_start_raw = (ho * stride_h) as isize - pad_h as isize;
151 let w_start_raw = (wo * stride_w) as isize - pad_w as isize;
152
153 let mut max_val = -f32::INFINITY;
154
155 for kh in 0..k_h {
156 for kw in 0..k_w {
157 let h_in_idx = h_start_raw + kh as isize;
158 let w_in_idx = w_start_raw + kw as isize;
159
160 if h_in_idx >= 0
161 && h_in_idx < h_in as isize
162 && w_in_idx >= 0
163 && w_in_idx < w_in as isize
164 {
165 let val = input_data
166 [((b * c + ci) * h_in + h_in_idx as usize) * w_in + w_in_idx as usize];
167 if val > max_val {
168 max_val = val;
169 }
170 }
171 }
172 }
173 max_val
174 })
175 .collect();
176
177 let storage = Storage::new(result_data);
178 let mut tensor = Tensor::new_with_storage(storage, &[n, c, h_out, w_out]);
179
180 if input.requires_grad() {
181 tensor.set_requires_grad_mut(true);
182 tensor.set_op(Arc::new(MaxPool2dBackward {
183 input: input.clone(),
184 kernel_size,
185 stride,
186 padding,
187 }));
188 }
189
190 tensor
191}