//! Pooling layers for `yscv_model` (`layers/pool.rs`): fixed max/average
//! pooling, global average pooling, and adaptive pooling over NHWC tensors.
1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// 2D max-pooling layer (NHWC layout).
///
/// Holds only the window/stride configuration; all four values are
/// validated to be non-zero by [`MaxPool2dLayer::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2dLayer {
    kernel_h: usize, // pooling window height
    kernel_w: usize, // pooling window width
    stride_h: usize, // vertical step between windows
    stride_w: usize, // horizontal step between windows
}
14
15impl MaxPool2dLayer {
16    pub fn new(
17        kernel_h: usize,
18        kernel_w: usize,
19        stride_h: usize,
20        stride_w: usize,
21    ) -> Result<Self, ModelError> {
22        if kernel_h == 0 || kernel_w == 0 {
23            return Err(ModelError::InvalidPoolKernel { kernel_h, kernel_w });
24        }
25        if stride_h == 0 || stride_w == 0 {
26            return Err(ModelError::InvalidPoolStride { stride_h, stride_w });
27        }
28        Ok(Self {
29            kernel_h,
30            kernel_w,
31            stride_h,
32            stride_w,
33        })
34    }
35
36    pub fn kernel_h(&self) -> usize {
37        self.kernel_h
38    }
39    pub fn kernel_w(&self) -> usize {
40        self.kernel_w
41    }
42    pub fn stride_h(&self) -> usize {
43        self.stride_h
44    }
45    pub fn stride_w(&self) -> usize {
46        self.stride_w
47    }
48
49    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
50        graph
51            .max_pool2d_nhwc(
52                input,
53                self.kernel_h,
54                self.kernel_w,
55                self.stride_h,
56                self.stride_w,
57            )
58            .map_err(Into::into)
59    }
60
61    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
62        yscv_kernels::max_pool2d_nhwc(
63            input,
64            self.kernel_h,
65            self.kernel_w,
66            self.stride_h,
67            self.stride_w,
68        )
69        .map_err(Into::into)
70    }
71}
72
/// 2D average-pooling layer (NHWC layout).
///
/// Holds only the window/stride configuration; all four values are
/// validated to be non-zero by [`AvgPool2dLayer::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AvgPool2dLayer {
    kernel_h: usize, // pooling window height
    kernel_w: usize, // pooling window width
    stride_h: usize, // vertical step between windows
    stride_w: usize, // horizontal step between windows
}
81
82impl AvgPool2dLayer {
83    pub fn new(
84        kernel_h: usize,
85        kernel_w: usize,
86        stride_h: usize,
87        stride_w: usize,
88    ) -> Result<Self, ModelError> {
89        if kernel_h == 0 || kernel_w == 0 {
90            return Err(ModelError::InvalidPoolKernel { kernel_h, kernel_w });
91        }
92        if stride_h == 0 || stride_w == 0 {
93            return Err(ModelError::InvalidPoolStride { stride_h, stride_w });
94        }
95        Ok(Self {
96            kernel_h,
97            kernel_w,
98            stride_h,
99            stride_w,
100        })
101    }
102
103    pub fn kernel_h(&self) -> usize {
104        self.kernel_h
105    }
106    pub fn kernel_w(&self) -> usize {
107        self.kernel_w
108    }
109    pub fn stride_h(&self) -> usize {
110        self.stride_h
111    }
112    pub fn stride_w(&self) -> usize {
113        self.stride_w
114    }
115
116    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
117        graph
118            .avg_pool2d_nhwc(
119                input,
120                self.kernel_h,
121                self.kernel_w,
122                self.stride_h,
123                self.stride_w,
124            )
125            .map_err(Into::into)
126    }
127
128    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
129        yscv_kernels::avg_pool2d_nhwc(
130            input,
131            self.kernel_h,
132            self.kernel_w,
133            self.stride_h,
134            self.stride_w,
135        )
136        .map_err(Into::into)
137    }
138}
139
/// Global average pooling: NHWC `[N,H,W,C]` -> `[N,1,1,C]`.
///
/// Zero-sized marker type: it carries no configuration, so `new()` and
/// `Default::default()` produce the same value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GlobalAvgPool2dLayer;
143
144impl GlobalAvgPool2dLayer {
145    pub fn new() -> Self {
146        Self
147    }
148
149    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
150        let shape = input.shape();
151        if shape.len() != 4 {
152            return Err(ModelError::InvalidFlattenShape {
153                got: shape.to_vec(),
154            });
155        }
156        let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
157        let hw = (h * w) as f32;
158        let data = input.data();
159        let mut out = vec![0.0f32; n * c];
160        for batch in 0..n {
161            for ch in 0..c {
162                let mut sum = 0.0f32;
163                for y in 0..h {
164                    for x in 0..w {
165                        sum += data[((batch * h + y) * w + x) * c + ch];
166                    }
167                }
168                out[batch * c + ch] = sum / hw;
169            }
170        }
171        Tensor::from_vec(vec![n, 1, 1, c], out).map_err(Into::into)
172    }
173
174    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
175        let shape = graph.value(input)?.shape().to_vec();
176        if shape.len() != 4 {
177            return Err(ModelError::InvalidFlattenShape { got: shape });
178        }
179        let (h, w) = (shape[1], shape[2]);
180        graph.avg_pool2d_nhwc(input, h, w, 1, 1).map_err(Into::into)
181    }
182}
183
/// Adaptive average pooling: output a fixed spatial size regardless of input size.
///
/// NHWC layout: `[batch, H, W, C]` -> `[batch, out_h, out_w, C]`.
///
/// NOTE(review): unlike the fixed pooling layers, construction does not
/// reject zero output dimensions — confirm whether that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveAvgPool2dLayer {
    out_h: usize, // target output height
    out_w: usize, // target output width
}
192
193impl AdaptiveAvgPool2dLayer {
194    pub fn new(out_h: usize, out_w: usize) -> Self {
195        Self { out_h, out_w }
196    }
197
198    pub fn output_h(&self) -> usize {
199        self.out_h
200    }
201    pub fn output_w(&self) -> usize {
202        self.out_w
203    }
204
205    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
206        graph
207            .adaptive_avg_pool2d_nhwc(input, self.out_h, self.out_w)
208            .map_err(Into::into)
209    }
210
211    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
212        let shape = input.shape();
213        if shape.len() != 4 {
214            return Err(ModelError::InvalidInputShape {
215                expected_features: 0,
216                got: shape.to_vec(),
217            });
218        }
219        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
220        let data = input.data();
221        let mut out = vec![0.0f32; batch * self.out_h * self.out_w * c];
222
223        for b in 0..batch {
224            for oh in 0..self.out_h {
225                let h_start = oh * h / self.out_h;
226                let h_end = ((oh + 1) * h / self.out_h).max(h_start + 1);
227                for ow in 0..self.out_w {
228                    let w_start = ow * w / self.out_w;
229                    let w_end = ((ow + 1) * w / self.out_w).max(w_start + 1);
230                    let count = (h_end - h_start) * (w_end - w_start);
231                    for ch in 0..c {
232                        let mut sum = 0.0f32;
233                        for ih in h_start..h_end {
234                            for iw in w_start..w_end {
235                                sum += data[((b * h + ih) * w + iw) * c + ch];
236                            }
237                        }
238                        out[((b * self.out_h + oh) * self.out_w + ow) * c + ch] =
239                            sum / count as f32;
240                    }
241                }
242            }
243        }
244        Ok(Tensor::from_vec(
245            vec![batch, self.out_h, self.out_w, c],
246            out,
247        )?)
248    }
249}
250
/// Adaptive max pooling: output a fixed spatial size.
///
/// NHWC layout: `[batch, H, W, C]` -> `[batch, out_h, out_w, C]`.
///
/// NOTE(review): unlike the fixed pooling layers, construction does not
/// reject zero output dimensions — confirm whether that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveMaxPool2dLayer {
    out_h: usize, // target output height
    out_w: usize, // target output width
}
257
258impl AdaptiveMaxPool2dLayer {
    /// Builds a layer that pools any input down to `out_h` x `out_w`.
    ///
    /// NOTE(review): zero output dimensions are accepted here (the fixed
    /// pooling layers validate theirs) — confirm intended.
    pub fn new(out_h: usize, out_w: usize) -> Self {
        Self { out_h, out_w }
    }
262
    /// Target output height.
    pub fn output_h(&self) -> usize {
        self.out_h
    }
    /// Target output width.
    pub fn output_w(&self) -> usize {
        self.out_w
    }
269
270    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
271        graph
272            .adaptive_max_pool2d_nhwc(input, self.out_h, self.out_w)
273            .map_err(Into::into)
274    }
275
276    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
277        let shape = input.shape();
278        if shape.len() != 4 {
279            return Err(ModelError::InvalidInputShape {
280                expected_features: 0,
281                got: shape.to_vec(),
282            });
283        }
284        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
285        let data = input.data();
286        let mut out = vec![f32::NEG_INFINITY; batch * self.out_h * self.out_w * c];
287
288        for b in 0..batch {
289            for oh in 0..self.out_h {
290                let h_start = oh * h / self.out_h;
291                let h_end = ((oh + 1) * h / self.out_h).max(h_start + 1);
292                for ow in 0..self.out_w {
293                    let w_start = ow * w / self.out_w;
294                    let w_end = ((ow + 1) * w / self.out_w).max(w_start + 1);
295                    for ch in 0..c {
296                        let mut max_v = f32::NEG_INFINITY;
297                        for ih in h_start..h_end {
298                            for iw in w_start..w_end {
299                                let v = data[((b * h + ih) * w + iw) * c + ch];
300                                if v > max_v {
301                                    max_v = v;
302                                }
303                            }
304                        }
305                        out[((b * self.out_h + oh) * self.out_w + ow) * c + ch] = max_v;
306                    }
307                }
308            }
309        }
310        Ok(Tensor::from_vec(
311            vec![batch, self.out_h, self.out_w, c],
312            out,
313        )?)
314    }
315}