1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// 2-D max pooling over NHWC tensors with a fixed kernel and stride.
///
/// Construct via [`MaxPool2dLayer::new`], which rejects zero-sized kernel or
/// stride dimensions, so a constructed value always holds positive sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2dLayer {
    // Window height/width in input pixels; `new` guarantees both are > 0.
    kernel_h: usize,
    kernel_w: usize,
    // Vertical/horizontal step between windows; `new` guarantees both are > 0.
    stride_h: usize,
    stride_w: usize,
}
14
15impl MaxPool2dLayer {
16 pub fn new(
17 kernel_h: usize,
18 kernel_w: usize,
19 stride_h: usize,
20 stride_w: usize,
21 ) -> Result<Self, ModelError> {
22 if kernel_h == 0 || kernel_w == 0 {
23 return Err(ModelError::InvalidPoolKernel { kernel_h, kernel_w });
24 }
25 if stride_h == 0 || stride_w == 0 {
26 return Err(ModelError::InvalidPoolStride { stride_h, stride_w });
27 }
28 Ok(Self {
29 kernel_h,
30 kernel_w,
31 stride_h,
32 stride_w,
33 })
34 }
35
36 pub fn kernel_h(&self) -> usize {
37 self.kernel_h
38 }
39 pub fn kernel_w(&self) -> usize {
40 self.kernel_w
41 }
42 pub fn stride_h(&self) -> usize {
43 self.stride_h
44 }
45 pub fn stride_w(&self) -> usize {
46 self.stride_w
47 }
48
49 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
50 graph
51 .max_pool2d_nhwc(
52 input,
53 self.kernel_h,
54 self.kernel_w,
55 self.stride_h,
56 self.stride_w,
57 )
58 .map_err(Into::into)
59 }
60
61 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
62 yscv_kernels::max_pool2d_nhwc(
63 input,
64 self.kernel_h,
65 self.kernel_w,
66 self.stride_h,
67 self.stride_w,
68 )
69 .map_err(Into::into)
70 }
71}
72
/// 2-D average pooling over NHWC tensors with a fixed kernel and stride.
///
/// Construct via [`AvgPool2dLayer::new`], which rejects zero-sized kernel or
/// stride dimensions, so a constructed value always holds positive sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AvgPool2dLayer {
    // Window height/width in input pixels; `new` guarantees both are > 0.
    kernel_h: usize,
    kernel_w: usize,
    // Vertical/horizontal step between windows; `new` guarantees both are > 0.
    stride_h: usize,
    stride_w: usize,
}
81
82impl AvgPool2dLayer {
83 pub fn new(
84 kernel_h: usize,
85 kernel_w: usize,
86 stride_h: usize,
87 stride_w: usize,
88 ) -> Result<Self, ModelError> {
89 if kernel_h == 0 || kernel_w == 0 {
90 return Err(ModelError::InvalidPoolKernel { kernel_h, kernel_w });
91 }
92 if stride_h == 0 || stride_w == 0 {
93 return Err(ModelError::InvalidPoolStride { stride_h, stride_w });
94 }
95 Ok(Self {
96 kernel_h,
97 kernel_w,
98 stride_h,
99 stride_w,
100 })
101 }
102
103 pub fn kernel_h(&self) -> usize {
104 self.kernel_h
105 }
106 pub fn kernel_w(&self) -> usize {
107 self.kernel_w
108 }
109 pub fn stride_h(&self) -> usize {
110 self.stride_h
111 }
112 pub fn stride_w(&self) -> usize {
113 self.stride_w
114 }
115
116 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
117 graph
118 .avg_pool2d_nhwc(
119 input,
120 self.kernel_h,
121 self.kernel_w,
122 self.stride_h,
123 self.stride_w,
124 )
125 .map_err(Into::into)
126 }
127
128 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
129 yscv_kernels::avg_pool2d_nhwc(
130 input,
131 self.kernel_h,
132 self.kernel_w,
133 self.stride_h,
134 self.stride_w,
135 )
136 .map_err(Into::into)
137 }
138}
139
/// Global average pooling: collapses the entire H x W extent of an NHWC
/// tensor to a single value per channel. Stateless, hence a unit struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GlobalAvgPool2dLayer;
143
144impl GlobalAvgPool2dLayer {
145 pub fn new() -> Self {
146 Self
147 }
148
149 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
150 let shape = input.shape();
151 if shape.len() != 4 {
152 return Err(ModelError::InvalidFlattenShape {
153 got: shape.to_vec(),
154 });
155 }
156 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
157 let hw = (h * w) as f32;
158 let data = input.data();
159 let mut out = vec![0.0f32; n * c];
160 for batch in 0..n {
161 for ch in 0..c {
162 let mut sum = 0.0f32;
163 for y in 0..h {
164 for x in 0..w {
165 sum += data[((batch * h + y) * w + x) * c + ch];
166 }
167 }
168 out[batch * c + ch] = sum / hw;
169 }
170 }
171 Tensor::from_vec(vec![n, 1, 1, c], out).map_err(Into::into)
172 }
173
174 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
175 let shape = graph.value(input)?.shape().to_vec();
176 if shape.len() != 4 {
177 return Err(ModelError::InvalidFlattenShape { got: shape });
178 }
179 let (h, w) = (shape[1], shape[2]);
180 graph.avg_pool2d_nhwc(input, h, w, 1, 1).map_err(Into::into)
181 }
182}
183
/// Adaptive average pooling to a fixed `out_h x out_w` output, regardless of
/// the input's spatial size.
///
/// NOTE(review): unlike the fixed-size pool layers, `new` accepts zero output
/// dimensions (it is infallible); a zero dim presumably yields an empty
/// output tensor — confirm that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveAvgPool2dLayer {
    // Target output height/width in pixels.
    out_h: usize,
    out_w: usize,
}
192
193impl AdaptiveAvgPool2dLayer {
194 pub fn new(out_h: usize, out_w: usize) -> Self {
195 Self { out_h, out_w }
196 }
197
198 pub fn output_h(&self) -> usize {
199 self.out_h
200 }
201 pub fn output_w(&self) -> usize {
202 self.out_w
203 }
204
205 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
206 graph
207 .adaptive_avg_pool2d_nhwc(input, self.out_h, self.out_w)
208 .map_err(Into::into)
209 }
210
211 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
212 let shape = input.shape();
213 if shape.len() != 4 {
214 return Err(ModelError::InvalidInputShape {
215 expected_features: 0,
216 got: shape.to_vec(),
217 });
218 }
219 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
220 let data = input.data();
221 let mut out = vec![0.0f32; batch * self.out_h * self.out_w * c];
222
223 for b in 0..batch {
224 for oh in 0..self.out_h {
225 let h_start = oh * h / self.out_h;
226 let h_end = ((oh + 1) * h / self.out_h).max(h_start + 1);
227 for ow in 0..self.out_w {
228 let w_start = ow * w / self.out_w;
229 let w_end = ((ow + 1) * w / self.out_w).max(w_start + 1);
230 let count = (h_end - h_start) * (w_end - w_start);
231 for ch in 0..c {
232 let mut sum = 0.0f32;
233 for ih in h_start..h_end {
234 for iw in w_start..w_end {
235 sum += data[((b * h + ih) * w + iw) * c + ch];
236 }
237 }
238 out[((b * self.out_h + oh) * self.out_w + ow) * c + ch] =
239 sum / count as f32;
240 }
241 }
242 }
243 }
244 Ok(Tensor::from_vec(
245 vec![batch, self.out_h, self.out_w, c],
246 out,
247 )?)
248 }
249}
250
/// Adaptive max pooling to a fixed `out_h x out_w` output, regardless of the
/// input's spatial size.
///
/// NOTE(review): unlike the fixed-size pool layers, `new` accepts zero output
/// dimensions (it is infallible); a zero dim presumably yields an empty
/// output tensor — confirm that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveMaxPool2dLayer {
    // Target output height/width in pixels.
    out_h: usize,
    out_w: usize,
}
257
258impl AdaptiveMaxPool2dLayer {
259 pub fn new(out_h: usize, out_w: usize) -> Self {
260 Self { out_h, out_w }
261 }
262
263 pub fn output_h(&self) -> usize {
264 self.out_h
265 }
266 pub fn output_w(&self) -> usize {
267 self.out_w
268 }
269
270 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
271 graph
272 .adaptive_max_pool2d_nhwc(input, self.out_h, self.out_w)
273 .map_err(Into::into)
274 }
275
276 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
277 let shape = input.shape();
278 if shape.len() != 4 {
279 return Err(ModelError::InvalidInputShape {
280 expected_features: 0,
281 got: shape.to_vec(),
282 });
283 }
284 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
285 let data = input.data();
286 let mut out = vec![f32::NEG_INFINITY; batch * self.out_h * self.out_w * c];
287
288 for b in 0..batch {
289 for oh in 0..self.out_h {
290 let h_start = oh * h / self.out_h;
291 let h_end = ((oh + 1) * h / self.out_h).max(h_start + 1);
292 for ow in 0..self.out_w {
293 let w_start = ow * w / self.out_w;
294 let w_end = ((ow + 1) * w / self.out_w).max(w_start + 1);
295 for ch in 0..c {
296 let mut max_v = f32::NEG_INFINITY;
297 for ih in h_start..h_end {
298 for iw in w_start..w_end {
299 let v = data[((b * h + ih) * w + iw) * c + ch];
300 if v > max_v {
301 max_v = v;
302 }
303 }
304 }
305 out[((b * self.out_h + oh) * self.out_w + ow) * c + ch] = max_v;
306 }
307 }
308 }
309 }
310 Ok(Tensor::from_vec(
311 vec![batch, self.out_h, self.out_w, c],
312 out,
313 )?)
314 }
315}