zenu_layer/layers/max_pool_2d.rs

1use std::collections::HashMap;
2
3use zenu_autograd::{
4    nn::pool2d::{max_pool_2d, MaxPool2dConfig},
5    Variable,
6};
7use zenu_matrix::{device::Device, num::Num};
8
9use crate::{Module, Parameters};
10
11pub struct MaxPool2d<T: Num> {
12    stride: (usize, usize),
13    kernel_size: (usize, usize),
14    pad: (usize, usize),
15    config: MaxPool2dConfig<T>,
16}
17
18impl<T: Num, D: Device> Parameters<T, D> for MaxPool2d<T> {
19    fn weights(&self) -> HashMap<String, Variable<T, D>> {
20        HashMap::new()
21    }
22
23    fn biases(&self) -> HashMap<String, Variable<T, D>> {
24        HashMap::new()
25    }
26}
27
28impl<T: Num> MaxPool2d<T> {
29    #[must_use]
30    pub fn new(kernel_size: (usize, usize), stride: (usize, usize), pad: (usize, usize)) -> Self {
31        Self {
32            stride,
33            kernel_size,
34            pad,
35            config: MaxPool2dConfig::default(),
36        }
37    }
38}
39
40impl<T: Num, D: Device> Module<T, D> for MaxPool2d<T> {
41    type Input = Variable<T, D>;
42    type Output = Variable<T, D>;
43    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
44        max_pool_2d(
45            input,
46            self.kernel_size,
47            self.stride,
48            self.pad,
49            self.config.clone(),
50        )
51    }
52}