// torsh_functional/pooling/basic.rs

use crate::utils::{calculate_pooling_output_size, function_context, validate_tensor_dims};
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
7#[allow(clippy::too_many_arguments)]
9pub fn max_pool1d(
10 input: &Tensor,
11 kernel_size: usize,
12 stride: Option<usize>,
13 padding: usize,
14 dilation: usize,
15 return_indices: bool,
16) -> TorshResult<(Tensor, Option<Tensor>)> {
17 let stride = stride.unwrap_or(kernel_size);
18
19 let context = function_context("max_pool1d");
20 validate_tensor_dims(input, 3, &context)?;
21
22 let shape = input.shape();
23 let dims = shape.dims();
24 let batch_size = dims[0];
25 let channels = dims[1];
26 let length = dims[2];
27
28 let out_length = calculate_pooling_output_size(length, kernel_size, stride, padding, dilation);
29
30 let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_length];
31 let mut indices_data = if return_indices {
32 Some(vec![0i64; batch_size * channels * out_length])
33 } else {
34 None
35 };
36
37 let input_data = input.to_vec()?;
38
39 for b in 0..batch_size {
40 for c in 0..channels {
41 for ol in 0..out_length {
42 let out_idx = (b * channels + c) * out_length + ol;
43 let mut max_val = f32::NEG_INFINITY;
44 let mut max_idx = 0;
45
46 for kl in 0..kernel_size {
47 let il = ol * stride + kl * dilation;
48
49 if il >= padding && il < length + padding {
50 let real_il = il - padding;
51
52 if real_il < length {
53 let in_idx = (b * channels + c) * length + real_il;
54 let val = input_data[in_idx];
55
56 if val > max_val {
57 max_val = val;
58 max_idx = in_idx as i64;
59 }
60 }
61 }
62 }
63
64 output_data[out_idx] = max_val;
65 if let Some(ref mut indices) = indices_data {
66 indices[out_idx] = max_idx;
67 }
68 }
69 }
70 }
71
72 let output = Tensor::from_data(
73 output_data,
74 vec![batch_size, channels, out_length],
75 input.device(),
76 )?;
77
78 let indices = if let Some(indices_data) = indices_data {
79 let indices_f32: Vec<f32> = indices_data.iter().map(|&idx| idx as f32).collect();
80 Some(Tensor::from_data(
81 indices_f32,
82 vec![batch_size, channels, out_length],
83 input.device(),
84 )?)
85 } else {
86 None
87 };
88
89 Ok((output, indices))
90}
91
92#[allow(clippy::too_many_arguments)]
94pub fn max_pool2d(
95 input: &Tensor,
96 kernel_size: (usize, usize),
97 stride: Option<(usize, usize)>,
98 padding: (usize, usize),
99 dilation: (usize, usize),
100 ceil_mode: bool,
101 return_indices: bool,
102) -> TorshResult<(Tensor, Option<Tensor>)> {
103 let stride = stride.unwrap_or(kernel_size);
104
105 let context = function_context("max_pool2d");
106 validate_tensor_dims(input, 4, &context)?;
107
108 let shape = input.shape();
109 let dims = shape.dims();
110 let batch_size = dims[0];
111 let channels = dims[1];
112 let height = dims[2];
113 let width = dims[3];
114
115 let out_height = if ceil_mode {
116 ((height + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) as f32 / stride.0 as f32)
117 .ceil() as usize
118 } else {
119 calculate_pooling_output_size(height, kernel_size.0, stride.0, padding.0, dilation.0)
120 };
121
122 let out_width = if ceil_mode {
123 ((width + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) as f32 / stride.1 as f32)
124 .ceil() as usize
125 } else {
126 calculate_pooling_output_size(width, kernel_size.1, stride.1, padding.1, dilation.1)
127 };
128
129 let output_size = batch_size * channels * out_height * out_width;
130 let mut output_data = vec![f32::NEG_INFINITY; output_size];
131 let mut indices_data = if return_indices {
132 Some(vec![0i64; output_size])
133 } else {
134 None
135 };
136
137 let input_data = input.to_vec()?;
138
139 for b in 0..batch_size {
140 for c in 0..channels {
141 for oh in 0..out_height {
142 for ow in 0..out_width {
143 let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
144 let mut max_val = f32::NEG_INFINITY;
145 let mut max_idx = 0;
146
147 for kh in 0..kernel_size.0 {
148 for kw in 0..kernel_size.1 {
149 let ih = oh * stride.0 + kh * dilation.0;
150 let iw = ow * stride.1 + kw * dilation.1;
151
152 if ih >= padding.0
153 && ih < height + padding.0
154 && iw >= padding.1
155 && iw < width + padding.1
156 {
157 let real_ih = ih - padding.0;
158 let real_iw = iw - padding.1;
159
160 if real_ih < height && real_iw < width {
161 let in_idx =
162 ((b * channels + c) * height + real_ih) * width + real_iw;
163 let val = input_data[in_idx];
164
165 if val > max_val {
166 max_val = val;
167 max_idx = in_idx as i64;
168 }
169 }
170 }
171 }
172 }
173
174 output_data[out_idx] = max_val;
175 if let Some(ref mut indices) = indices_data {
176 indices[out_idx] = max_idx;
177 }
178 }
179 }
180 }
181 }
182
183 let output = Tensor::from_data(
184 output_data,
185 vec![batch_size, channels, out_height, out_width],
186 input.device(),
187 )?;
188
189 let indices = if let Some(indices_data) = indices_data {
190 let indices_f32: Vec<f32> = indices_data.iter().map(|&idx| idx as f32).collect();
191 Some(Tensor::from_data(
192 indices_f32,
193 vec![batch_size, channels, out_height, out_width],
194 input.device(),
195 )?)
196 } else {
197 None
198 };
199
200 Ok((output, indices))
201}
202
203pub fn avg_pool1d(
205 input: &Tensor,
206 kernel_size: usize,
207 stride: Option<usize>,
208 padding: usize,
209 ceil_mode: bool,
210 count_include_pad: bool,
211) -> TorshResult<Tensor> {
212 let stride = stride.unwrap_or(kernel_size);
213
214 let context = function_context("avg_pool1d");
215 validate_tensor_dims(input, 3, &context)?;
216
217 let shape = input.shape();
218 let dims = shape.dims();
219 let batch_size = dims[0];
220 let channels = dims[1];
221 let length = dims[2];
222
223 let out_length = if ceil_mode {
224 ((length + 2 * padding - kernel_size) as f32 / stride as f32).ceil() as usize + 1
225 } else {
226 calculate_pooling_output_size(length, kernel_size, stride, padding, 1)
227 };
228
229 let mut output_data = vec![0.0f32; batch_size * channels * out_length];
230 let input_data = input.to_vec()?;
231
232 for b in 0..batch_size {
233 for c in 0..channels {
234 for ol in 0..out_length {
235 let out_idx = (b * channels + c) * out_length + ol;
236 let mut sum = 0.0f32;
237 let mut count = 0;
238
239 for kl in 0..kernel_size {
240 let il = ol * stride + kl;
241
242 if il >= padding && il < length + padding {
243 let real_il = il - padding;
244
245 if real_il < length {
246 let in_idx = (b * channels + c) * length + real_il;
247 sum += input_data[in_idx];
248 count += 1;
249 } else if count_include_pad {
250 count += 1;
251 }
252 } else if count_include_pad {
253 count += 1;
254 }
255 }
256
257 if count > 0 {
258 output_data[out_idx] = sum / count as f32;
259 }
260 }
261 }
262 }
263
264 Tensor::from_data(
265 output_data,
266 vec![batch_size, channels, out_length],
267 input.device(),
268 )
269}
270
271pub fn avg_pool2d(
273 input: &Tensor,
274 kernel_size: (usize, usize),
275 stride: Option<(usize, usize)>,
276 padding: (usize, usize),
277 ceil_mode: bool,
278 count_include_pad: bool,
279 divisor_override: Option<usize>,
280) -> TorshResult<Tensor> {
281 let stride = stride.unwrap_or(kernel_size);
282
283 let context = function_context("avg_pool2d");
284 validate_tensor_dims(input, 4, &context)?;
285
286 let shape = input.shape();
287 let dims = shape.dims();
288 let batch_size = dims[0];
289 let channels = dims[1];
290 let height = dims[2];
291 let width = dims[3];
292
293 let out_height = if ceil_mode {
294 ((height + 2 * padding.0 - kernel_size.0) as f32 / stride.0 as f32).ceil() as usize + 1
295 } else {
296 calculate_pooling_output_size(height, kernel_size.0, stride.0, padding.0, 1)
297 };
298
299 let out_width = if ceil_mode {
300 ((width + 2 * padding.1 - kernel_size.1) as f32 / stride.1 as f32).ceil() as usize + 1
301 } else {
302 calculate_pooling_output_size(width, kernel_size.1, stride.1, padding.1, 1)
303 };
304
305 let output_size = batch_size * channels * out_height * out_width;
306 let mut output_data = vec![0.0f32; output_size];
307 let input_data = input.to_vec()?;
308
309 for b in 0..batch_size {
310 for c in 0..channels {
311 for oh in 0..out_height {
312 for ow in 0..out_width {
313 let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
314 let mut sum = 0.0f32;
315 let mut count = 0;
316
317 for kh in 0..kernel_size.0 {
318 for kw in 0..kernel_size.1 {
319 let ih = oh * stride.0 + kh;
320 let iw = ow * stride.1 + kw;
321
322 if ih >= padding.0
323 && ih < height + padding.0
324 && iw >= padding.1
325 && iw < width + padding.1
326 {
327 let real_ih = ih - padding.0;
328 let real_iw = iw - padding.1;
329
330 if real_ih < height && real_iw < width {
331 let in_idx =
332 ((b * channels + c) * height + real_ih) * width + real_iw;
333 sum += input_data[in_idx];
334 count += 1;
335 } else if count_include_pad {
336 count += 1;
337 }
338 } else if count_include_pad {
339 count += 1;
340 }
341 }
342 }
343
344 let divisor = divisor_override.unwrap_or(count) as f32;
345 if divisor > 0.0 {
346 output_data[out_idx] = sum / divisor;
347 }
348 }
349 }
350 }
351 }
352
353 Tensor::from_data(
354 output_data,
355 vec![batch_size, channels, out_height, out_width],
356 input.device(),
357 )
358}
359
360pub fn max_pool3d(
362 input: &Tensor,
363 kernel_size: (usize, usize, usize),
364 stride: Option<(usize, usize, usize)>,
365 padding: (usize, usize, usize),
366 dilation: (usize, usize, usize),
367 ceil_mode: bool,
368) -> TorshResult<Tensor> {
369 let stride = stride.unwrap_or(kernel_size);
370
371 let context = function_context("max_pool3d");
372 validate_tensor_dims(input, 5, &context)?;
373
374 let shape = input.shape();
375 let dims = shape.dims();
376 let batch_size = dims[0];
377 let channels = dims[1];
378 let depth = dims[2];
379 let height = dims[3];
380 let width = dims[4];
381
382 let effective_kernel = (
383 (kernel_size.0 - 1) * dilation.0 + 1,
384 (kernel_size.1 - 1) * dilation.1 + 1,
385 (kernel_size.2 - 1) * dilation.2 + 1,
386 );
387
388 let out_depth = if ceil_mode {
389 ((depth + 2 * padding.0 - effective_kernel.0) as f32 / stride.0 as f32).ceil() as usize + 1
390 } else {
391 (depth + 2 * padding.0 - effective_kernel.0) / stride.0 + 1
392 };
393
394 let out_height = if ceil_mode {
395 ((height + 2 * padding.1 - effective_kernel.1) as f32 / stride.1 as f32).ceil() as usize + 1
396 } else {
397 (height + 2 * padding.1 - effective_kernel.1) / stride.1 + 1
398 };
399
400 let out_width = if ceil_mode {
401 ((width + 2 * padding.2 - effective_kernel.2) as f32 / stride.2 as f32).ceil() as usize + 1
402 } else {
403 (width + 2 * padding.2 - effective_kernel.2) / stride.2 + 1
404 };
405
406 let output_size = batch_size * channels * out_depth * out_height * out_width;
407 let mut output_data = vec![f32::NEG_INFINITY; output_size];
408 let input_data = input.to_vec()?;
409
410 for b in 0..batch_size {
411 for c in 0..channels {
412 for od in 0..out_depth {
413 for oh in 0..out_height {
414 for ow in 0..out_width {
415 let out_idx = (((b * channels + c) * out_depth + od) * out_height + oh)
416 * out_width
417 + ow;
418
419 for kd in 0..kernel_size.0 {
420 for kh in 0..kernel_size.1 {
421 for kw in 0..kernel_size.2 {
422 let id = od * stride.0 + kd * dilation.0;
423 let ih = oh * stride.1 + kh * dilation.1;
424 let iw = ow * stride.2 + kw * dilation.2;
425
426 if id >= padding.0
427 && id < depth + padding.0
428 && ih >= padding.1
429 && ih < height + padding.1
430 && iw >= padding.2
431 && iw < width + padding.2
432 {
433 let real_id = id - padding.0;
434 let real_ih = ih - padding.1;
435 let real_iw = iw - padding.2;
436
437 if real_id < depth && real_ih < height && real_iw < width {
438 let in_idx = (((b * channels + c) * depth + real_id)
439 * height
440 + real_ih)
441 * width
442 + real_iw;
443 let val = input_data[in_idx];
444 if val > output_data[out_idx] {
445 output_data[out_idx] = val;
446 }
447 }
448 }
449 }
450 }
451 }
452
453 if output_data[out_idx] == f32::NEG_INFINITY {
455 output_data[out_idx] = 0.0;
456 }
457 }
458 }
459 }
460 }
461 }
462
463 Tensor::from_data(
464 output_data,
465 vec![batch_size, channels, out_depth, out_height, out_width],
466 input.device(),
467 )
468}
469
470pub fn avg_pool3d(
472 input: &Tensor,
473 kernel_size: (usize, usize, usize),
474 stride: Option<(usize, usize, usize)>,
475 padding: (usize, usize, usize),
476 ceil_mode: bool,
477 count_include_pad: bool,
478) -> TorshResult<Tensor> {
479 let stride = stride.unwrap_or(kernel_size);
480
481 let context = function_context("avg_pool3d");
482 validate_tensor_dims(input, 5, &context)?;
483
484 let shape = input.shape();
485 let dims = shape.dims();
486 let batch_size = dims[0];
487 let channels = dims[1];
488 let depth = dims[2];
489 let height = dims[3];
490 let width = dims[4];
491
492 let out_depth = if ceil_mode {
493 ((depth + 2 * padding.0 - kernel_size.0) as f32 / stride.0 as f32).ceil() as usize + 1
494 } else {
495 (depth + 2 * padding.0 - kernel_size.0) / stride.0 + 1
496 };
497
498 let out_height = if ceil_mode {
499 ((height + 2 * padding.1 - kernel_size.1) as f32 / stride.1 as f32).ceil() as usize + 1
500 } else {
501 (height + 2 * padding.1 - kernel_size.1) / stride.1 + 1
502 };
503
504 let out_width = if ceil_mode {
505 ((width + 2 * padding.2 - kernel_size.2) as f32 / stride.2 as f32).ceil() as usize + 1
506 } else {
507 (width + 2 * padding.2 - kernel_size.2) / stride.2 + 1
508 };
509
510 let output_size = batch_size * channels * out_depth * out_height * out_width;
511 let mut output_data = vec![0.0f32; output_size];
512 let input_data = input.to_vec()?;
513
514 for b in 0..batch_size {
515 for c in 0..channels {
516 for od in 0..out_depth {
517 for oh in 0..out_height {
518 for ow in 0..out_width {
519 let out_idx = (((b * channels + c) * out_depth + od) * out_height + oh)
520 * out_width
521 + ow;
522
523 let mut sum = 0.0f32;
524 let mut count = 0;
525
526 for kd in 0..kernel_size.0 {
527 for kh in 0..kernel_size.1 {
528 for kw in 0..kernel_size.2 {
529 let id = od * stride.0 + kd;
530 let ih = oh * stride.1 + kh;
531 let iw = ow * stride.2 + kw;
532
533 if count_include_pad
534 || (id >= padding.0
535 && id < depth + padding.0
536 && ih >= padding.1
537 && ih < height + padding.1
538 && iw >= padding.2
539 && iw < width + padding.2)
540 {
541 if id >= padding.0
542 && id < depth + padding.0
543 && ih >= padding.1
544 && ih < height + padding.1
545 && iw >= padding.2
546 && iw < width + padding.2
547 {
548 let real_id = id - padding.0;
549 let real_ih = ih - padding.1;
550 let real_iw = iw - padding.2;
551
552 if real_id < depth
553 && real_ih < height
554 && real_iw < width
555 {
556 let in_idx = (((b * channels + c) * depth
557 + real_id)
558 * height
559 + real_ih)
560 * width
561 + real_iw;
562 sum += input_data[in_idx];
563 }
564 }
565 count += 1;
566 }
567 }
568 }
569 }
570
571 if count > 0 {
572 output_data[out_idx] = sum / count as f32;
573 }
574 }
575 }
576 }
577 }
578 }
579
580 Tensor::from_data(
581 output_data,
582 vec![batch_size, channels, out_depth, out_height, out_width],
583 input.device(),
584 )
585}