zenu_layer/layers/
max_pool_2d.rs1use std::collections::HashMap;
2
3use zenu_autograd::{
4 nn::pool2d::{max_pool_2d, MaxPool2dConfig},
5 Variable,
6};
7use zenu_matrix::{device::Device, num::Num};
8
9use crate::{Module, Parameters};
10
11pub struct MaxPool2d<T: Num> {
12 stride: (usize, usize),
13 kernel_size: (usize, usize),
14 pad: (usize, usize),
15 config: MaxPool2dConfig<T>,
16}
17
18impl<T: Num, D: Device> Parameters<T, D> for MaxPool2d<T> {
19 fn weights(&self) -> HashMap<String, Variable<T, D>> {
20 HashMap::new()
21 }
22
23 fn biases(&self) -> HashMap<String, Variable<T, D>> {
24 HashMap::new()
25 }
26}
27
28impl<T: Num> MaxPool2d<T> {
29 #[must_use]
30 pub fn new(kernel_size: (usize, usize), stride: (usize, usize), pad: (usize, usize)) -> Self {
31 Self {
32 stride,
33 kernel_size,
34 pad,
35 config: MaxPool2dConfig::default(),
36 }
37 }
38}
39
40impl<T: Num, D: Device> Module<T, D> for MaxPool2d<T> {
41 type Input = Variable<T, D>;
42 type Output = Variable<T, D>;
43 fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
44 max_pool_2d(
45 input,
46 self.kernel_size,
47 self.stride,
48 self.pad,
49 self.config.clone(),
50 )
51 }
52}