torsh_python/nn/pooling.rs

//! Pooling layers
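//!
//! These classes are modeled after the corresponding `torch.nn` pooling layers
//! (`MaxPool2d`, `AvgPool2d`, `AdaptiveAvgPool2d`, `AdaptiveMaxPool2d`).
//! A rough Python-side usage sketch follows; the `torsh.nn` module path is an
//! assumption and is not defined in this file:
//!
//! ```python
//! pool = torsh.nn.MaxPool2d(2)   # kernel_size=2; stride defaults to kernel_size
//! y = pool.forward(x)            # x: NCHW tensor, e.g. shape (1, 3, 32, 32)
//! # y has shape (1, 3, 16, 16)
//! ```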

use super::module::PyModule;
use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;

/// 2D Max Pooling layer
#[pyclass(name = "MaxPool2d", extends = PyModule)]
pub struct PyMaxPool2d {
    kernel_size: (usize, usize),
    stride: Option<(usize, usize)>,
    padding: (usize, usize),
    dilation: (usize, usize),
    ceil_mode: bool,
    return_indices: bool,
}

#[pymethods]
impl PyMaxPool2d {
    #[new]
    #[pyo3(signature = (kernel_size, stride=None, padding=None, dilation=None, ceil_mode=None, return_indices=None))]
    fn new(
        kernel_size: Py<PyAny>,
        stride: Option<Py<PyAny>>,
        padding: Option<Py<PyAny>>,
        dilation: Option<Py<PyAny>>,
        ceil_mode: Option<bool>,
        return_indices: Option<bool>,
    ) -> PyResult<(Self, PyModule)> {
        // Parse kernel size
        let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
            if let Ok(size) = kernel_size.extract::<usize>(py) {
                Ok((size, size))
            } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
                Ok(tuple)
            } else {
                Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "kernel_size must be an integer or tuple of integers",
                ))
            }
        })?;

        // Parse stride (defaults to kernel_size if None)
        let stride = if let Some(stride_obj) = stride {
            Some(Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(stride) = stride_obj.extract::<usize>(py) {
                    Ok((stride, stride))
                } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "stride must be an integer or tuple of integers",
                    ))
                }
            })?)
        } else {
            None
        };

        // Parse padding
        let padding = if let Some(padding_obj) = padding {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(padding) = padding_obj.extract::<usize>(py) {
                    Ok((padding, padding))
                } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "padding must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (0, 0)
        };

        // Parse dilation
        let dilation = if let Some(dilation_obj) = dilation {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(dilation) = dilation_obj.extract::<usize>(py) {
                    Ok((dilation, dilation))
                } else if let Ok(tuple) = dilation_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "dilation must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (1, 1)
        };

        Ok((
            Self {
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode: ceil_mode.unwrap_or(false),
                return_indices: return_indices.unwrap_or(false),
            },
            PyModule::new(),
        ))
    }

    /// Forward pass through max pool 2d
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Naive CPU max pooling over an NCHW input.
        // NOTE: return_indices is stored but indices are not returned yet.
        let shape = input.tensor.shape().dims().to_vec();

        // Expect 4D input: (batch, channels, height, width)
        if shape.len() != 4 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Expected 4D input (NCHW), got {}D",
                shape.len()
            )));
        }

        let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
        let (kh, kw) = self.kernel_size;
        let (stride_h, stride_w) = self.stride.unwrap_or(self.kernel_size);
        let (pad_h, pad_w) = self.padding;
        let (dil_h, dil_w) = self.dilation;

        // Calculate output dimensions; the effective kernel extent is
        // dilation * (kernel - 1) + 1.
        let out_h = if self.ceil_mode {
            ((in_h + 2 * pad_h - dil_h * (kh - 1) - 1) as f32 / stride_h as f32).ceil() as usize
                + 1
        } else {
            (in_h + 2 * pad_h - dil_h * (kh - 1) - 1) / stride_h + 1
        };
        let out_w = if self.ceil_mode {
            ((in_w + 2 * pad_w - dil_w * (kw - 1) - 1) as f32 / stride_w as f32).ceil() as usize
                + 1
        } else {
            (in_w + 2 * pad_w - dil_w * (kw - 1) - 1) / stride_w + 1
        };
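        // Worked example for the floor-mode formula above: in_h = 32,
        // kernel = 3, stride = 2, padding = 1, dilation = 1 gives
        // out_h = (32 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 31 / 2 + 1 = 16
        // (integer division).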

        let input_data = py_result!(input.tensor.data())?;
        let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_h * out_w];

        // Perform max pooling
        for b in 0..batch_size {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;

                        for kh_idx in 0..kh {
                            for kw_idx in 0..kw {
                                let ih = (oh * stride_h + kh_idx * dil_h) as i32 - pad_h as i32;
                                let iw = (ow * stride_w + kw_idx * dil_w) as i32 - pad_w as i32;

                                if ih >= 0 && ih < in_h as i32 && iw >= 0 && iw < in_w as i32 {
                                    let input_idx = b * channels * in_h * in_w
                                        + c * in_h * in_w
                                        + ih as usize * in_w
                                        + iw as usize;
                                    max_val = max_val.max(input_data[input_idx]);
                                }
                            }
                        }

                        let output_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[output_idx] = max_val;
                    }
                }
            }
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            output_data,
            vec![batch_size, channels, out_h, out_w],
            input.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    /// Get layer parameters (MaxPool2d has no parameters)
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        Ok(Vec::new())
    }

    /// Get named parameters (MaxPool2d has no parameters)
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        Ok(HashMap::new())
    }

    /// String representation
    fn __repr__(&self) -> String {
        let stride_str = if let Some(stride) = self.stride {
            format!("stride={:?}", stride)
        } else {
            "stride=None".to_string()
        };
        format!(
            "MaxPool2d(kernel_size={:?}, {}, padding={:?}, dilation={:?}, ceil_mode={}, return_indices={})",
            self.kernel_size, stride_str, self.padding, self.dilation, self.ceil_mode, self.return_indices
        )
    }
}

/// 2D Average Pooling layer
#[pyclass(name = "AvgPool2d", extends = PyModule)]
pub struct PyAvgPool2d {
    kernel_size: (usize, usize),
    stride: Option<(usize, usize)>,
    padding: (usize, usize),
    ceil_mode: bool,
    count_include_pad: bool,
    divisor_override: Option<usize>,
}

#[pymethods]
impl PyAvgPool2d {
    #[new]
    #[pyo3(signature = (kernel_size, stride=None, padding=None, ceil_mode=None, count_include_pad=None, divisor_override=None))]
    fn new(
        kernel_size: Py<PyAny>,
        stride: Option<Py<PyAny>>,
        padding: Option<Py<PyAny>>,
        ceil_mode: Option<bool>,
        count_include_pad: Option<bool>,
        divisor_override: Option<usize>,
    ) -> PyResult<(Self, PyModule)> {
        // Parse kernel size
        let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
            if let Ok(size) = kernel_size.extract::<usize>(py) {
                Ok((size, size))
            } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
                Ok(tuple)
            } else {
                Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "kernel_size must be an integer or tuple of integers",
                ))
            }
        })?;

        // Parse stride
        let stride = if let Some(stride_obj) = stride {
            Some(Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(stride) = stride_obj.extract::<usize>(py) {
                    Ok((stride, stride))
                } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "stride must be an integer or tuple of integers",
                    ))
                }
            })?)
        } else {
            None
        };

        // Parse padding
        let padding = if let Some(padding_obj) = padding {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(padding) = padding_obj.extract::<usize>(py) {
                    Ok((padding, padding))
                } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "padding must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (0, 0)
        };

        Ok((
            Self {
                kernel_size,
                stride,
                padding,
                ceil_mode: ceil_mode.unwrap_or(false),
                count_include_pad: count_include_pad.unwrap_or(true),
                divisor_override,
            },
            PyModule::new(),
        ))
    }

    /// Forward pass through average pool 2d
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Naive CPU average pooling over an NCHW input.
        let shape = input.tensor.shape().dims().to_vec();

        // Expect 4D input: (batch, channels, height, width)
        if shape.len() != 4 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Expected 4D input (NCHW), got {}D",
                shape.len()
            )));
        }

        let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
        let (kh, kw) = self.kernel_size;
        let (stride_h, stride_w) = self.stride.unwrap_or(self.kernel_size);
        let (pad_h, pad_w) = self.padding;

        // Calculate output dimensions
        let out_h = if self.ceil_mode {
            ((in_h + 2 * pad_h - kh) as f32 / stride_h as f32).ceil() as usize + 1
        } else {
            (in_h + 2 * pad_h - kh) / stride_h + 1
        };
        let out_w = if self.ceil_mode {
            ((in_w + 2 * pad_w - kw) as f32 / stride_w as f32).ceil() as usize + 1
        } else {
            (in_w + 2 * pad_w - kw) / stride_w + 1
        };
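        // Worked example: in_h = 5, kernel = 2, stride = 2, padding = 0 gives
        // (5 - 2) / 2 + 1 = 2 rows in floor mode, while ceil_mode keeps the
        // trailing partial window: ceil(3 / 2) + 1 = 3 rows.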

        let input_data = py_result!(input.tensor.data())?;
        let mut output_data = vec![0.0; batch_size * channels * out_h * out_w];

        // Perform average pooling
        for b in 0..batch_size {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0;
                        let mut count = 0;

                        for kh_idx in 0..kh {
                            for kw_idx in 0..kw {
                                let ih = (oh * stride_h + kh_idx) as i32 - pad_h as i32;
                                let iw = (ow * stride_w + kw_idx) as i32 - pad_w as i32;

                                if ih >= 0 && ih < in_h as i32 && iw >= 0 && iw < in_w as i32 {
                                    let input_idx = b * channels * in_h * in_w
                                        + c * in_h * in_w
                                        + ih as usize * in_w
                                        + iw as usize;
                                    sum += input_data[input_idx];
                                    count += 1;
                                } else if self.count_include_pad {
                                    count += 1;
                                }
                            }
                        }

                        let divisor = if let Some(div) = self.divisor_override {
                            div as f32
                        } else {
                            count as f32
                        };
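                        // Example: a 3x3 kernel at the top-left corner with
                        // padding = 1 sees 4 in-bounds values, so the divisor
                        // is 9 with count_include_pad (padding counted) and 4
                        // without it, unless divisor_override is given.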

                        let output_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[output_idx] = if divisor > 0.0 { sum / divisor } else { 0.0 };
                    }
                }
            }
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            output_data,
            vec![batch_size, channels, out_h, out_w],
            input.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    /// Get layer parameters (AvgPool2d has no parameters)
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        Ok(Vec::new())
    }

    /// Get named parameters (AvgPool2d has no parameters)
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        Ok(HashMap::new())
    }

    /// String representation
    fn __repr__(&self) -> String {
        let stride_str = if let Some(stride) = self.stride {
            format!("stride={:?}", stride)
        } else {
            "stride=None".to_string()
        };
        let divisor_str = if let Some(divisor) = self.divisor_override {
            format!("divisor_override={}", divisor)
        } else {
            "divisor_override=None".to_string()
        };
        format!(
            "AvgPool2d(kernel_size={:?}, {}, padding={:?}, ceil_mode={}, count_include_pad={}, {})",
            self.kernel_size,
            stride_str,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
            divisor_str
        )
    }
}

/// Adaptive Average Pooling 2D layer
#[pyclass(name = "AdaptiveAvgPool2d", extends = PyModule)]
pub struct PyAdaptiveAvgPool2d {
    output_size: (usize, usize),
}

#[pymethods]
impl PyAdaptiveAvgPool2d {
    #[new]
    fn new(output_size: Py<PyAny>) -> PyResult<(Self, PyModule)> {
        // Parse output size
        let output_size = Python::attach(|py| -> PyResult<(usize, usize)> {
            if let Ok(size) = output_size.extract::<usize>(py) {
                Ok((size, size))
            } else if let Ok(tuple) = output_size.extract::<(usize, usize)>(py) {
                Ok(tuple)
            } else {
                Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "output_size must be an integer or tuple of integers",
                ))
            }
        })?;

        Ok((Self { output_size }, PyModule::new()))
    }

    /// Forward pass through adaptive average pool 2d
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Naive CPU adaptive average pooling over an NCHW input.
        let shape = input.tensor.shape().dims().to_vec();

        // Expect 4D input: (batch, channels, height, width)
        if shape.len() != 4 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Expected 4D input (NCHW), got {}D",
                shape.len()
            )));
        }

        let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
        let (out_h, out_w) = self.output_size;

        let input_data = py_result!(input.tensor.data())?;
        let mut output_data = vec![0.0; batch_size * channels * out_h * out_w];

        // Perform adaptive average pooling
        for b in 0..batch_size {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        // Adaptive pooling window: start = floor(o * in / out),
                        // end = ceil((o + 1) * in / out), so every window is
                        // non-empty even when the output is larger than the input.
                        let start_h = (oh * in_h) / out_h;
                        let end_h = ((oh + 1) * in_h + out_h - 1) / out_h;
                        let start_w = (ow * in_w) / out_w;
                        let end_w = ((ow + 1) * in_w + out_w - 1) / out_w;
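                        // Example: in_h = 7, out_h = 3 yields row windows
                        // [0, 3), [2, 5), [4, 7); adjacent windows may overlap
                        // and each output element averages its own window.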

                        let mut sum = 0.0;
                        let mut count = 0;

                        for ih in start_h..end_h {
                            for iw in start_w..end_w {
                                let input_idx =
                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
                                sum += input_data[input_idx];
                                count += 1;
                            }
                        }

                        let output_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[output_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            output_data,
            vec![batch_size, channels, out_h, out_w],
            input.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    /// Get layer parameters (AdaptiveAvgPool2d has no parameters)
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        Ok(Vec::new())
    }

    /// Get named parameters (AdaptiveAvgPool2d has no parameters)
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        Ok(HashMap::new())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!("AdaptiveAvgPool2d(output_size={:?})", self.output_size)
    }
}

/// Adaptive Max Pooling 2D layer
#[pyclass(name = "AdaptiveMaxPool2d", extends = PyModule)]
pub struct PyAdaptiveMaxPool2d {
    output_size: (usize, usize),
    return_indices: bool,
}

#[pymethods]
impl PyAdaptiveMaxPool2d {
    #[new]
    #[pyo3(signature = (output_size, return_indices=None))]
    fn new(output_size: Py<PyAny>, return_indices: Option<bool>) -> PyResult<(Self, PyModule)> {
        // Parse output size
        let output_size = Python::attach(|py| -> PyResult<(usize, usize)> {
            if let Ok(size) = output_size.extract::<usize>(py) {
                Ok((size, size))
            } else if let Ok(tuple) = output_size.extract::<(usize, usize)>(py) {
                Ok(tuple)
            } else {
                Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "output_size must be an integer or tuple of integers",
                ))
            }
        })?;

        Ok((
            Self {
                output_size,
                return_indices: return_indices.unwrap_or(false),
            },
            PyModule::new(),
        ))
    }

    /// Forward pass through adaptive max pool 2d
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Naive CPU adaptive max pooling over an NCHW input.
        // NOTE: return_indices is stored but indices are not returned yet.
        let shape = input.tensor.shape().dims().to_vec();

        // Expect 4D input: (batch, channels, height, width)
        if shape.len() != 4 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Expected 4D input (NCHW), got {}D",
                shape.len()
            )));
        }

        let (batch_size, channels, in_h, in_w) = (shape[0], shape[1], shape[2], shape[3]);
        let (out_h, out_w) = self.output_size;

        let input_data = py_result!(input.tensor.data())?;
        let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_h * out_w];

        // Perform adaptive max pooling
        for b in 0..batch_size {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        // Adaptive pooling window: start = floor(o * in / out),
                        // end = ceil((o + 1) * in / out), matching the window
                        // rule used by AdaptiveAvgPool2d above.
                        let start_h = (oh * in_h) / out_h;
                        let end_h = ((oh + 1) * in_h + out_h - 1) / out_h;
                        let start_w = (ow * in_w) / out_w;
                        let end_w = ((ow + 1) * in_w + out_w - 1) / out_w;

                        let mut max_val = f32::NEG_INFINITY;

                        for ih in start_h..end_h {
                            for iw in start_w..end_w {
                                let input_idx =
                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
                                max_val = max_val.max(input_data[input_idx]);
                            }
                        }

                        let output_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[output_idx] = max_val;
                    }
                }
            }
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            output_data,
            vec![batch_size, channels, out_h, out_w],
            input.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    /// Get layer parameters (AdaptiveMaxPool2d has no parameters)
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        Ok(Vec::new())
    }

    /// Get named parameters (AdaptiveMaxPool2d has no parameters)
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        Ok(HashMap::new())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "AdaptiveMaxPool2d(output_size={:?}, return_indices={})",
            self.output_size, self.return_indices
        )
    }
}