1use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::PyAny;
7use std::collections::HashMap;
8
/// Python-visible 2D max-pooling layer (PyTorch-style `MaxPool2d` API).
#[pyclass(name = "MaxPool2d", extends = PyModule)]
pub struct PyMaxPool2d {
    /// Pooling window size as (height, width).
    kernel_size: (usize, usize),
    /// Window step as (height, width); `None` means "same as kernel_size".
    stride: Option<(usize, usize)>,
    /// Implicit padding on both spatial edges as (height, width).
    padding: (usize, usize),
    /// Spacing between kernel taps as (height, width).
    dilation: (usize, usize),
    /// Use ceil instead of floor when computing the output spatial size.
    ceil_mode: bool,
    /// Whether argmax indices should also be returned.
    /// NOTE(review): stored but `forward` never returns indices — confirm intended API.
    return_indices: bool,
}
19
20#[pymethods]
21impl PyMaxPool2d {
22 #[new]
23 fn new(
24 kernel_size: Py<PyAny>,
25 stride: Option<Py<PyAny>>,
26 padding: Option<Py<PyAny>>,
27 dilation: Option<Py<PyAny>>,
28 ceil_mode: Option<bool>,
29 return_indices: Option<bool>,
30 ) -> PyResult<(Self, PyModule)> {
31 let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
33 if let Ok(size) = kernel_size.extract::<usize>(py) {
34 Ok((size, size))
35 } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
36 Ok(tuple)
37 } else {
38 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
39 "kernel_size must be an integer or tuple of integers",
40 ))
41 }
42 })?;
43
44 let stride = if let Some(stride_obj) = stride {
46 Some(Python::attach(|py| -> PyResult<(usize, usize)> {
47 if let Ok(stride) = stride_obj.extract::<usize>(py) {
48 Ok((stride, stride))
49 } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
50 Ok(tuple)
51 } else {
52 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
53 "stride must be an integer or tuple of integers",
54 ))
55 }
56 })?)
57 } else {
58 None
59 };
60
61 let padding = if let Some(padding_obj) = padding {
63 Python::attach(|py| -> PyResult<(usize, usize)> {
64 if let Ok(padding) = padding_obj.extract::<usize>(py) {
65 Ok((padding, padding))
66 } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
67 Ok(tuple)
68 } else {
69 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
70 "padding must be an integer or tuple of integers",
71 ))
72 }
73 })?
74 } else {
75 (0, 0)
76 };
77
78 let dilation = if let Some(dilation_obj) = dilation {
80 Python::attach(|py| -> PyResult<(usize, usize)> {
81 if let Ok(dilation) = dilation_obj.extract::<usize>(py) {
82 Ok((dilation, dilation))
83 } else if let Ok(tuple) = dilation_obj.extract::<(usize, usize)>(py) {
84 Ok(tuple)
85 } else {
86 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
87 "dilation must be an integer or tuple of integers",
88 ))
89 }
90 })?
91 } else {
92 (1, 1)
93 };
94
95 Ok((
96 Self {
97 kernel_size,
98 stride,
99 padding,
100 dilation,
101 ceil_mode: ceil_mode.unwrap_or(false),
102 return_indices: return_indices.unwrap_or(false),
103 },
104 PyModule::new(),
105 ))
106 }
107
108 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
110 let shape = input.tensor.shape().dims().to_vec();
112
113 if shape.len() != 4 {
115 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
116 "Expected 4D input (NCHW), got {}D",
117 shape.len()
118 )));
119 }
120
121 let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
122 let (kh, kw) = self.kernel_size;
123 let (stride_h, stride_w) = self.stride.unwrap_or(self.kernel_size);
124 let (pad_h, pad_w) = self.padding;
125
126 let out_h = if self.ceil_mode {
128 ((in_h + 2 * pad_h - kh) as f32 / stride_h as f32).ceil() as usize + 1
129 } else {
130 (in_h + 2 * pad_h - kh) / stride_h + 1
131 };
132 let out_w = if self.ceil_mode {
133 ((in_w + 2 * pad_w - kw) as f32 / stride_w as f32).ceil() as usize + 1
134 } else {
135 (in_w + 2 * pad_w - kw) / stride_w + 1
136 };
137
138 let input_data = py_result!(input.tensor.data())?;
139 let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_h * out_w];
140
141 for b in 0..batch_size {
143 for c in 0..channels {
144 for oh in 0..out_h {
145 for ow in 0..out_w {
146 let mut max_val = f32::NEG_INFINITY;
147
148 for kh_idx in 0..kh {
149 for kw_idx in 0..kw {
150 let ih = (oh * stride_h + kh_idx) as i32 - pad_h as i32;
151 let iw = (ow * stride_w + kw_idx) as i32 - pad_w as i32;
152
153 if ih >= 0 && ih < in_h as i32 && iw >= 0 && iw < in_w as i32 {
154 let input_idx = b * channels * in_h * in_w
155 + c * in_h * in_w
156 + ih as usize * in_w
157 + iw as usize;
158 max_val = max_val.max(input_data[input_idx]);
159 }
160 }
161 }
162
163 let output_idx =
164 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
165 output_data[output_idx] = max_val;
166 }
167 }
168 }
169 }
170
171 let result = py_result!(torsh_tensor::Tensor::from_data(
172 output_data,
173 vec![batch_size, channels, out_h, out_w],
174 input.tensor.device()
175 ))?;
176
177 Ok(PyTensor { tensor: result })
178 }
179
180 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
182 Ok(Vec::new())
183 }
184
185 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
187 Ok(HashMap::new())
188 }
189
190 fn __repr__(&self) -> String {
192 let stride_str = if let Some(stride) = self.stride {
193 format!("stride={:?}", stride)
194 } else {
195 "stride=None".to_string()
196 };
197 format!(
198 "MaxPool2d(kernel_size={:?}, {}, padding={:?}, dilation={:?}, ceil_mode={}, return_indices={})",
199 self.kernel_size, stride_str, self.padding, self.dilation, self.ceil_mode, self.return_indices
200 )
201 }
202}
203
/// Python-visible 2D average-pooling layer (PyTorch-style `AvgPool2d` API).
#[pyclass(name = "AvgPool2d", extends = PyModule)]
pub struct PyAvgPool2d {
    /// Pooling window size as (height, width).
    kernel_size: (usize, usize),
    /// Window step as (height, width); `None` means "same as kernel_size".
    stride: Option<(usize, usize)>,
    /// Implicit zero padding on both spatial edges as (height, width).
    padding: (usize, usize),
    /// Use ceil instead of floor when computing the output spatial size.
    ceil_mode: bool,
    /// If true, padded (zero) positions are included in the averaging divisor.
    count_include_pad: bool,
    /// When set, replaces the per-window element count as the divisor.
    divisor_override: Option<usize>,
}
214
215#[pymethods]
216impl PyAvgPool2d {
217 #[new]
218 fn new(
219 kernel_size: Py<PyAny>,
220 stride: Option<Py<PyAny>>,
221 padding: Option<Py<PyAny>>,
222 ceil_mode: Option<bool>,
223 count_include_pad: Option<bool>,
224 divisor_override: Option<usize>,
225 ) -> PyResult<(Self, PyModule)> {
226 let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
228 if let Ok(size) = kernel_size.extract::<usize>(py) {
229 Ok((size, size))
230 } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
231 Ok(tuple)
232 } else {
233 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
234 "kernel_size must be an integer or tuple of integers",
235 ))
236 }
237 })?;
238
239 let stride = if let Some(stride_obj) = stride {
241 Some(Python::attach(|py| -> PyResult<(usize, usize)> {
242 if let Ok(stride) = stride_obj.extract::<usize>(py) {
243 Ok((stride, stride))
244 } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
245 Ok(tuple)
246 } else {
247 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
248 "stride must be an integer or tuple of integers",
249 ))
250 }
251 })?)
252 } else {
253 None
254 };
255
256 let padding = if let Some(padding_obj) = padding {
258 Python::attach(|py| -> PyResult<(usize, usize)> {
259 if let Ok(padding) = padding_obj.extract::<usize>(py) {
260 Ok((padding, padding))
261 } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
262 Ok(tuple)
263 } else {
264 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
265 "padding must be an integer or tuple of integers",
266 ))
267 }
268 })?
269 } else {
270 (0, 0)
271 };
272
273 Ok((
274 Self {
275 kernel_size,
276 stride,
277 padding,
278 ceil_mode: ceil_mode.unwrap_or(false),
279 count_include_pad: count_include_pad.unwrap_or(true),
280 divisor_override,
281 },
282 PyModule::new(),
283 ))
284 }
285
286 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
288 let shape = input.tensor.shape().dims().to_vec();
290
291 if shape.len() != 4 {
293 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
294 "Expected 4D input (NCHW), got {}D",
295 shape.len()
296 )));
297 }
298
299 let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
300 let (kh, kw) = self.kernel_size;
301 let (stride_h, stride_w) = self.stride.unwrap_or(self.kernel_size);
302 let (pad_h, pad_w) = self.padding;
303
304 let out_h = if self.ceil_mode {
306 ((in_h + 2 * pad_h - kh) as f32 / stride_h as f32).ceil() as usize + 1
307 } else {
308 (in_h + 2 * pad_h - kh) / stride_h + 1
309 };
310 let out_w = if self.ceil_mode {
311 ((in_w + 2 * pad_w - kw) as f32 / stride_w as f32).ceil() as usize + 1
312 } else {
313 (in_w + 2 * pad_w - kw) / stride_w + 1
314 };
315
316 let input_data = py_result!(input.tensor.data())?;
317 let mut output_data = vec![0.0; batch_size * channels * out_h * out_w];
318
319 for b in 0..batch_size {
321 for c in 0..channels {
322 for oh in 0..out_h {
323 for ow in 0..out_w {
324 let mut sum = 0.0;
325 let mut count = 0;
326
327 for kh_idx in 0..kh {
328 for kw_idx in 0..kw {
329 let ih = (oh * stride_h + kh_idx) as i32 - pad_h as i32;
330 let iw = (ow * stride_w + kw_idx) as i32 - pad_w as i32;
331
332 if ih >= 0 && ih < in_h as i32 && iw >= 0 && iw < in_w as i32 {
333 let input_idx = b * channels * in_h * in_w
334 + c * in_h * in_w
335 + ih as usize * in_w
336 + iw as usize;
337 sum += input_data[input_idx];
338 count += 1;
339 } else if self.count_include_pad {
340 count += 1;
341 }
342 }
343 }
344
345 let divisor = if let Some(div) = self.divisor_override {
346 div as f32
347 } else {
348 count as f32
349 };
350
351 let output_idx =
352 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
353 output_data[output_idx] = if divisor > 0.0 { sum / divisor } else { 0.0 };
354 }
355 }
356 }
357 }
358
359 let result = py_result!(torsh_tensor::Tensor::from_data(
360 output_data,
361 vec![batch_size, channels, out_h, out_w],
362 input.tensor.device()
363 ))?;
364
365 Ok(PyTensor { tensor: result })
366 }
367
368 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
370 Ok(Vec::new())
371 }
372
373 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
375 Ok(HashMap::new())
376 }
377
378 fn __repr__(&self) -> String {
380 let stride_str = if let Some(stride) = self.stride {
381 format!("stride={:?}", stride)
382 } else {
383 "stride=None".to_string()
384 };
385 let divisor_str = if let Some(divisor) = self.divisor_override {
386 format!("divisor_override={}", divisor)
387 } else {
388 "divisor_override=None".to_string()
389 };
390 format!(
391 "AvgPool2d(kernel_size={:?}, {}, padding={:?}, ceil_mode={}, count_include_pad={}, {})",
392 self.kernel_size,
393 stride_str,
394 self.padding,
395 self.ceil_mode,
396 self.count_include_pad,
397 divisor_str
398 )
399 }
400}
401
/// Python-visible adaptive average pooling: averages each input plane down
/// (or up) to a fixed spatial output size, regardless of the input size.
#[pyclass(name = "AdaptiveAvgPool2d", extends = PyModule)]
pub struct PyAdaptiveAvgPool2d {
    /// Target spatial size as (height, width).
    output_size: (usize, usize),
}
407
408#[pymethods]
409impl PyAdaptiveAvgPool2d {
410 #[new]
411 fn new(output_size: Py<PyAny>) -> PyResult<(Self, PyModule)> {
412 let output_size = Python::attach(|py| -> PyResult<(usize, usize)> {
414 if let Ok(size) = output_size.extract::<usize>(py) {
415 Ok((size, size))
416 } else if let Ok(tuple) = output_size.extract::<(usize, usize)>(py) {
417 Ok(tuple)
418 } else {
419 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
420 "output_size must be an integer or tuple of integers",
421 ))
422 }
423 })?;
424
425 Ok((Self { output_size }, PyModule::new()))
426 }
427
428 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
430 let shape = input.tensor.shape().dims().to_vec();
432
433 if shape.len() != 4 {
435 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
436 "Expected 4D input (NCHW), got {}D",
437 shape.len()
438 )));
439 }
440
441 let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
442 let (out_h, out_w) = self.output_size;
443
444 let input_data = py_result!(input.tensor.data())?;
445 let mut output_data = vec![0.0; batch_size * channels * out_h * out_w];
446
447 for b in 0..batch_size {
449 for c in 0..channels {
450 for oh in 0..out_h {
451 for ow in 0..out_w {
452 let start_h = (oh * in_h) / out_h;
454 let end_h = ((oh + 1) * in_h) / out_h;
455 let start_w = (ow * in_w) / out_w;
456 let end_w = ((ow + 1) * in_w) / out_w;
457
458 let mut sum = 0.0;
459 let mut count = 0;
460
461 for ih in start_h..end_h {
462 for iw in start_w..end_w {
463 let input_idx =
464 b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
465 sum += input_data[input_idx];
466 count += 1;
467 }
468 }
469
470 let output_idx =
471 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
472 output_data[output_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
473 }
474 }
475 }
476 }
477
478 let result = py_result!(torsh_tensor::Tensor::from_data(
479 output_data,
480 vec![batch_size, channels, out_h, out_w],
481 input.tensor.device()
482 ))?;
483
484 Ok(PyTensor { tensor: result })
485 }
486
487 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
489 Ok(Vec::new())
490 }
491
492 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
494 Ok(HashMap::new())
495 }
496
497 fn __repr__(&self) -> String {
499 format!("AdaptiveAvgPool2d(output_size={:?})", self.output_size)
500 }
501}
502
/// Python-visible adaptive max pooling: reduces each input plane to a fixed
/// spatial output size by taking the max over computed sub-windows.
#[pyclass(name = "AdaptiveMaxPool2d", extends = PyModule)]
pub struct PyAdaptiveMaxPool2d {
    /// Target spatial size as (height, width).
    output_size: (usize, usize),
    /// Whether argmax indices should also be returned.
    /// NOTE(review): stored but `forward` never returns indices — confirm intended API.
    return_indices: bool,
}
509
510#[pymethods]
511impl PyAdaptiveMaxPool2d {
512 #[new]
513 fn new(output_size: Py<PyAny>, return_indices: Option<bool>) -> PyResult<(Self, PyModule)> {
514 let output_size = Python::attach(|py| -> PyResult<(usize, usize)> {
516 if let Ok(size) = output_size.extract::<usize>(py) {
517 Ok((size, size))
518 } else if let Ok(tuple) = output_size.extract::<(usize, usize)>(py) {
519 Ok(tuple)
520 } else {
521 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
522 "output_size must be an integer or tuple of integers",
523 ))
524 }
525 })?;
526
527 Ok((
528 Self {
529 output_size,
530 return_indices: return_indices.unwrap_or(false),
531 },
532 PyModule::new(),
533 ))
534 }
535
536 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
538 let shape = input.tensor.shape().dims().to_vec();
540
541 if shape.len() != 4 {
543 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
544 "Expected 4D input (NCHW), got {}D",
545 shape.len()
546 )));
547 }
548
549 let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
550 let (out_h, out_w) = self.output_size;
551
552 let input_data = py_result!(input.tensor.data())?;
553 let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_h * out_w];
554
555 for b in 0..batch_size {
557 for c in 0..channels {
558 for oh in 0..out_h {
559 for ow in 0..out_w {
560 let start_h = (oh * in_h) / out_h;
562 let end_h = ((oh + 1) * in_h) / out_h;
563 let start_w = (ow * in_w) / out_w;
564 let end_w = ((ow + 1) * in_w) / out_w;
565
566 let mut max_val = f32::NEG_INFINITY;
567
568 for ih in start_h..end_h {
569 for iw in start_w..end_w {
570 let input_idx =
571 b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
572 max_val = max_val.max(input_data[input_idx]);
573 }
574 }
575
576 let output_idx =
577 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
578 output_data[output_idx] = max_val;
579 }
580 }
581 }
582 }
583
584 let result = py_result!(torsh_tensor::Tensor::from_data(
585 output_data,
586 vec![batch_size, channels, out_h, out_w],
587 input.tensor.device()
588 ))?;
589
590 Ok(PyTensor { tensor: result })
591 }
592
593 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
595 Ok(Vec::new())
596 }
597
598 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
600 Ok(HashMap::new())
601 }
602
603 fn __repr__(&self) -> String {
605 format!(
606 "AdaptiveMaxPool2d(output_size={:?}, return_indices={})",
607 self.output_size, self.return_indices
608 )
609 }
610}