Skip to main content

torsh_functional/pooling/
basic.rs

1//! Basic pooling operations: max and average pooling in 1D, 2D, and 3D
2
3use crate::utils::{calculate_pooling_output_size, function_context, validate_tensor_dims};
4use torsh_core::Result as TorshResult;
5use torsh_tensor::Tensor;
6
7/// 1D max pooling
8#[allow(clippy::too_many_arguments)]
9pub fn max_pool1d(
10    input: &Tensor,
11    kernel_size: usize,
12    stride: Option<usize>,
13    padding: usize,
14    dilation: usize,
15    return_indices: bool,
16) -> TorshResult<(Tensor, Option<Tensor>)> {
17    let stride = stride.unwrap_or(kernel_size);
18
19    let context = function_context("max_pool1d");
20    validate_tensor_dims(input, 3, &context)?;
21
22    let shape = input.shape();
23    let dims = shape.dims();
24    let batch_size = dims[0];
25    let channels = dims[1];
26    let length = dims[2];
27
28    let out_length = calculate_pooling_output_size(length, kernel_size, stride, padding, dilation);
29
30    let mut output_data = vec![f32::NEG_INFINITY; batch_size * channels * out_length];
31    let mut indices_data = if return_indices {
32        Some(vec![0i64; batch_size * channels * out_length])
33    } else {
34        None
35    };
36
37    let input_data = input.to_vec()?;
38
39    for b in 0..batch_size {
40        for c in 0..channels {
41            for ol in 0..out_length {
42                let out_idx = (b * channels + c) * out_length + ol;
43                let mut max_val = f32::NEG_INFINITY;
44                let mut max_idx = 0;
45
46                for kl in 0..kernel_size {
47                    let il = ol * stride + kl * dilation;
48
49                    if il >= padding && il < length + padding {
50                        let real_il = il - padding;
51
52                        if real_il < length {
53                            let in_idx = (b * channels + c) * length + real_il;
54                            let val = input_data[in_idx];
55
56                            if val > max_val {
57                                max_val = val;
58                                max_idx = in_idx as i64;
59                            }
60                        }
61                    }
62                }
63
64                output_data[out_idx] = max_val;
65                if let Some(ref mut indices) = indices_data {
66                    indices[out_idx] = max_idx;
67                }
68            }
69        }
70    }
71
72    let output = Tensor::from_data(
73        output_data,
74        vec![batch_size, channels, out_length],
75        input.device(),
76    )?;
77
78    let indices = if let Some(indices_data) = indices_data {
79        let indices_f32: Vec<f32> = indices_data.iter().map(|&idx| idx as f32).collect();
80        Some(Tensor::from_data(
81            indices_f32,
82            vec![batch_size, channels, out_length],
83            input.device(),
84        )?)
85    } else {
86        None
87    };
88
89    Ok((output, indices))
90}
91
92/// 2D max pooling
93#[allow(clippy::too_many_arguments)]
94pub fn max_pool2d(
95    input: &Tensor,
96    kernel_size: (usize, usize),
97    stride: Option<(usize, usize)>,
98    padding: (usize, usize),
99    dilation: (usize, usize),
100    ceil_mode: bool,
101    return_indices: bool,
102) -> TorshResult<(Tensor, Option<Tensor>)> {
103    let stride = stride.unwrap_or(kernel_size);
104
105    let context = function_context("max_pool2d");
106    validate_tensor_dims(input, 4, &context)?;
107
108    let shape = input.shape();
109    let dims = shape.dims();
110    let batch_size = dims[0];
111    let channels = dims[1];
112    let height = dims[2];
113    let width = dims[3];
114
115    let out_height = if ceil_mode {
116        ((height + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) as f32 / stride.0 as f32)
117            .ceil() as usize
118    } else {
119        calculate_pooling_output_size(height, kernel_size.0, stride.0, padding.0, dilation.0)
120    };
121
122    let out_width = if ceil_mode {
123        ((width + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) as f32 / stride.1 as f32)
124            .ceil() as usize
125    } else {
126        calculate_pooling_output_size(width, kernel_size.1, stride.1, padding.1, dilation.1)
127    };
128
129    let output_size = batch_size * channels * out_height * out_width;
130    let mut output_data = vec![f32::NEG_INFINITY; output_size];
131    let mut indices_data = if return_indices {
132        Some(vec![0i64; output_size])
133    } else {
134        None
135    };
136
137    let input_data = input.to_vec()?;
138
139    for b in 0..batch_size {
140        for c in 0..channels {
141            for oh in 0..out_height {
142                for ow in 0..out_width {
143                    let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
144                    let mut max_val = f32::NEG_INFINITY;
145                    let mut max_idx = 0;
146
147                    for kh in 0..kernel_size.0 {
148                        for kw in 0..kernel_size.1 {
149                            let ih = oh * stride.0 + kh * dilation.0;
150                            let iw = ow * stride.1 + kw * dilation.1;
151
152                            if ih >= padding.0
153                                && ih < height + padding.0
154                                && iw >= padding.1
155                                && iw < width + padding.1
156                            {
157                                let real_ih = ih - padding.0;
158                                let real_iw = iw - padding.1;
159
160                                if real_ih < height && real_iw < width {
161                                    let in_idx =
162                                        ((b * channels + c) * height + real_ih) * width + real_iw;
163                                    let val = input_data[in_idx];
164
165                                    if val > max_val {
166                                        max_val = val;
167                                        max_idx = in_idx as i64;
168                                    }
169                                }
170                            }
171                        }
172                    }
173
174                    output_data[out_idx] = max_val;
175                    if let Some(ref mut indices) = indices_data {
176                        indices[out_idx] = max_idx;
177                    }
178                }
179            }
180        }
181    }
182
183    let output = Tensor::from_data(
184        output_data,
185        vec![batch_size, channels, out_height, out_width],
186        input.device(),
187    )?;
188
189    let indices = if let Some(indices_data) = indices_data {
190        let indices_f32: Vec<f32> = indices_data.iter().map(|&idx| idx as f32).collect();
191        Some(Tensor::from_data(
192            indices_f32,
193            vec![batch_size, channels, out_height, out_width],
194            input.device(),
195        )?)
196    } else {
197        None
198    };
199
200    Ok((output, indices))
201}
202
203/// 1D average pooling
204pub fn avg_pool1d(
205    input: &Tensor,
206    kernel_size: usize,
207    stride: Option<usize>,
208    padding: usize,
209    ceil_mode: bool,
210    count_include_pad: bool,
211) -> TorshResult<Tensor> {
212    let stride = stride.unwrap_or(kernel_size);
213
214    let context = function_context("avg_pool1d");
215    validate_tensor_dims(input, 3, &context)?;
216
217    let shape = input.shape();
218    let dims = shape.dims();
219    let batch_size = dims[0];
220    let channels = dims[1];
221    let length = dims[2];
222
223    let out_length = if ceil_mode {
224        ((length + 2 * padding - kernel_size) as f32 / stride as f32).ceil() as usize + 1
225    } else {
226        calculate_pooling_output_size(length, kernel_size, stride, padding, 1)
227    };
228
229    let mut output_data = vec![0.0f32; batch_size * channels * out_length];
230    let input_data = input.to_vec()?;
231
232    for b in 0..batch_size {
233        for c in 0..channels {
234            for ol in 0..out_length {
235                let out_idx = (b * channels + c) * out_length + ol;
236                let mut sum = 0.0f32;
237                let mut count = 0;
238
239                for kl in 0..kernel_size {
240                    let il = ol * stride + kl;
241
242                    if il >= padding && il < length + padding {
243                        let real_il = il - padding;
244
245                        if real_il < length {
246                            let in_idx = (b * channels + c) * length + real_il;
247                            sum += input_data[in_idx];
248                            count += 1;
249                        } else if count_include_pad {
250                            count += 1;
251                        }
252                    } else if count_include_pad {
253                        count += 1;
254                    }
255                }
256
257                if count > 0 {
258                    output_data[out_idx] = sum / count as f32;
259                }
260            }
261        }
262    }
263
264    Tensor::from_data(
265        output_data,
266        vec![batch_size, channels, out_length],
267        input.device(),
268    )
269}
270
271/// 2D average pooling
272pub fn avg_pool2d(
273    input: &Tensor,
274    kernel_size: (usize, usize),
275    stride: Option<(usize, usize)>,
276    padding: (usize, usize),
277    ceil_mode: bool,
278    count_include_pad: bool,
279    divisor_override: Option<usize>,
280) -> TorshResult<Tensor> {
281    let stride = stride.unwrap_or(kernel_size);
282
283    let context = function_context("avg_pool2d");
284    validate_tensor_dims(input, 4, &context)?;
285
286    let shape = input.shape();
287    let dims = shape.dims();
288    let batch_size = dims[0];
289    let channels = dims[1];
290    let height = dims[2];
291    let width = dims[3];
292
293    let out_height = if ceil_mode {
294        ((height + 2 * padding.0 - kernel_size.0) as f32 / stride.0 as f32).ceil() as usize + 1
295    } else {
296        calculate_pooling_output_size(height, kernel_size.0, stride.0, padding.0, 1)
297    };
298
299    let out_width = if ceil_mode {
300        ((width + 2 * padding.1 - kernel_size.1) as f32 / stride.1 as f32).ceil() as usize + 1
301    } else {
302        calculate_pooling_output_size(width, kernel_size.1, stride.1, padding.1, 1)
303    };
304
305    let output_size = batch_size * channels * out_height * out_width;
306    let mut output_data = vec![0.0f32; output_size];
307    let input_data = input.to_vec()?;
308
309    for b in 0..batch_size {
310        for c in 0..channels {
311            for oh in 0..out_height {
312                for ow in 0..out_width {
313                    let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
314                    let mut sum = 0.0f32;
315                    let mut count = 0;
316
317                    for kh in 0..kernel_size.0 {
318                        for kw in 0..kernel_size.1 {
319                            let ih = oh * stride.0 + kh;
320                            let iw = ow * stride.1 + kw;
321
322                            if ih >= padding.0
323                                && ih < height + padding.0
324                                && iw >= padding.1
325                                && iw < width + padding.1
326                            {
327                                let real_ih = ih - padding.0;
328                                let real_iw = iw - padding.1;
329
330                                if real_ih < height && real_iw < width {
331                                    let in_idx =
332                                        ((b * channels + c) * height + real_ih) * width + real_iw;
333                                    sum += input_data[in_idx];
334                                    count += 1;
335                                } else if count_include_pad {
336                                    count += 1;
337                                }
338                            } else if count_include_pad {
339                                count += 1;
340                            }
341                        }
342                    }
343
344                    let divisor = divisor_override.unwrap_or(count) as f32;
345                    if divisor > 0.0 {
346                        output_data[out_idx] = sum / divisor;
347                    }
348                }
349            }
350        }
351    }
352
353    Tensor::from_data(
354        output_data,
355        vec![batch_size, channels, out_height, out_width],
356        input.device(),
357    )
358}
359
/// Max pooling 3D
///
/// Applies dilated max pooling over a 5-D `(batch, channels, depth, height,
/// width)` tensor. Padding positions are treated as negative infinity during
/// the max, so they never win.
///
/// * `stride` — window step per axis; defaults to `kernel_size` when `None`
/// * `ceil_mode` — round the output sizes up instead of down
///
/// NOTE(review): unlike `max_pool1d`/`max_pool2d`, this variant has no
/// `return_indices` support, and windows that overlap no real input are
/// zero-filled rather than left at -inf (see end of the inner loop) —
/// confirm whether this divergence is intentional.
pub fn max_pool3d(
    input: &Tensor,
    kernel_size: (usize, usize, usize),
    stride: Option<(usize, usize, usize)>,
    padding: (usize, usize, usize),
    dilation: (usize, usize, usize),
    ceil_mode: bool,
) -> TorshResult<Tensor> {
    // Non-overlapping windows by default.
    let stride = stride.unwrap_or(kernel_size);

    let context = function_context("max_pool3d");
    validate_tensor_dims(input, 5, &context)?;

    let shape = input.shape();
    let dims = shape.dims();
    let batch_size = dims[0];
    let channels = dims[1];
    let depth = dims[2];
    let height = dims[3];
    let width = dims[4];

    // Span of a dilated kernel along each axis: (k - 1) * dilation + 1.
    let effective_kernel = (
        (kernel_size.0 - 1) * dilation.0 + 1,
        (kernel_size.1 - 1) * dilation.1 + 1,
        (kernel_size.2 - 1) * dilation.2 + 1,
    );

    // Output extent per axis: floor (default) or ceil of
    // (size + 2*pad - effective_kernel) / stride, plus one.
    let out_depth = if ceil_mode {
        ((depth + 2 * padding.0 - effective_kernel.0) as f32 / stride.0 as f32).ceil() as usize + 1
    } else {
        (depth + 2 * padding.0 - effective_kernel.0) / stride.0 + 1
    };

    let out_height = if ceil_mode {
        ((height + 2 * padding.1 - effective_kernel.1) as f32 / stride.1 as f32).ceil() as usize + 1
    } else {
        (height + 2 * padding.1 - effective_kernel.1) / stride.1 + 1
    };

    let out_width = if ceil_mode {
        ((width + 2 * padding.2 - effective_kernel.2) as f32 / stride.2 as f32).ceil() as usize + 1
    } else {
        (width + 2 * padding.2 - effective_kernel.2) / stride.2 + 1
    };

    let output_size = batch_size * channels * out_depth * out_height * out_width;
    // Maxima accumulate directly into the output buffer, so seed it with -inf.
    let mut output_data = vec![f32::NEG_INFINITY; output_size];
    let input_data = input.to_vec()?;

    for b in 0..batch_size {
        for c in 0..channels {
            for od in 0..out_depth {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let out_idx = (((b * channels + c) * out_depth + od) * out_height + oh)
                            * out_width
                            + ow;

                        for kd in 0..kernel_size.0 {
                            for kh in 0..kernel_size.1 {
                                for kw in 0..kernel_size.2 {
                                    // Tap position in the padded coordinate system.
                                    let id = od * stride.0 + kd * dilation.0;
                                    let ih = oh * stride.1 + kh * dilation.1;
                                    let iw = ow * stride.2 + kw * dilation.2;

                                    // Only taps that map onto real input contribute.
                                    if id >= padding.0
                                        && id < depth + padding.0
                                        && ih >= padding.1
                                        && ih < height + padding.1
                                        && iw >= padding.2
                                        && iw < width + padding.2
                                    {
                                        let real_id = id - padding.0;
                                        let real_ih = ih - padding.1;
                                        let real_iw = iw - padding.2;

                                        if real_id < depth && real_ih < height && real_iw < width {
                                            let in_idx = (((b * channels + c) * depth + real_id)
                                                * height
                                                + real_ih)
                                                * width
                                                + real_iw;
                                            let val = input_data[in_idx];
                                            if val > output_data[out_idx] {
                                                output_data[out_idx] = val;
                                            }
                                        }
                                    }
                                }
                            }
                        }

                        // Handle case where no valid input was found
                        // (possible with large padding or ceil-mode overhang).
                        if output_data[out_idx] == f32::NEG_INFINITY {
                            output_data[out_idx] = 0.0;
                        }
                    }
                }
            }
        }
    }

    Tensor::from_data(
        output_data,
        vec![batch_size, channels, out_depth, out_height, out_width],
        input.device(),
    )
}
469
470/// Average pooling 3D
471pub fn avg_pool3d(
472    input: &Tensor,
473    kernel_size: (usize, usize, usize),
474    stride: Option<(usize, usize, usize)>,
475    padding: (usize, usize, usize),
476    ceil_mode: bool,
477    count_include_pad: bool,
478) -> TorshResult<Tensor> {
479    let stride = stride.unwrap_or(kernel_size);
480
481    let context = function_context("avg_pool3d");
482    validate_tensor_dims(input, 5, &context)?;
483
484    let shape = input.shape();
485    let dims = shape.dims();
486    let batch_size = dims[0];
487    let channels = dims[1];
488    let depth = dims[2];
489    let height = dims[3];
490    let width = dims[4];
491
492    let out_depth = if ceil_mode {
493        ((depth + 2 * padding.0 - kernel_size.0) as f32 / stride.0 as f32).ceil() as usize + 1
494    } else {
495        (depth + 2 * padding.0 - kernel_size.0) / stride.0 + 1
496    };
497
498    let out_height = if ceil_mode {
499        ((height + 2 * padding.1 - kernel_size.1) as f32 / stride.1 as f32).ceil() as usize + 1
500    } else {
501        (height + 2 * padding.1 - kernel_size.1) / stride.1 + 1
502    };
503
504    let out_width = if ceil_mode {
505        ((width + 2 * padding.2 - kernel_size.2) as f32 / stride.2 as f32).ceil() as usize + 1
506    } else {
507        (width + 2 * padding.2 - kernel_size.2) / stride.2 + 1
508    };
509
510    let output_size = batch_size * channels * out_depth * out_height * out_width;
511    let mut output_data = vec![0.0f32; output_size];
512    let input_data = input.to_vec()?;
513
514    for b in 0..batch_size {
515        for c in 0..channels {
516            for od in 0..out_depth {
517                for oh in 0..out_height {
518                    for ow in 0..out_width {
519                        let out_idx = (((b * channels + c) * out_depth + od) * out_height + oh)
520                            * out_width
521                            + ow;
522
523                        let mut sum = 0.0f32;
524                        let mut count = 0;
525
526                        for kd in 0..kernel_size.0 {
527                            for kh in 0..kernel_size.1 {
528                                for kw in 0..kernel_size.2 {
529                                    let id = od * stride.0 + kd;
530                                    let ih = oh * stride.1 + kh;
531                                    let iw = ow * stride.2 + kw;
532
533                                    if count_include_pad
534                                        || (id >= padding.0
535                                            && id < depth + padding.0
536                                            && ih >= padding.1
537                                            && ih < height + padding.1
538                                            && iw >= padding.2
539                                            && iw < width + padding.2)
540                                    {
541                                        if id >= padding.0
542                                            && id < depth + padding.0
543                                            && ih >= padding.1
544                                            && ih < height + padding.1
545                                            && iw >= padding.2
546                                            && iw < width + padding.2
547                                        {
548                                            let real_id = id - padding.0;
549                                            let real_ih = ih - padding.1;
550                                            let real_iw = iw - padding.2;
551
552                                            if real_id < depth
553                                                && real_ih < height
554                                                && real_iw < width
555                                            {
556                                                let in_idx = (((b * channels + c) * depth
557                                                    + real_id)
558                                                    * height
559                                                    + real_ih)
560                                                    * width
561                                                    + real_iw;
562                                                sum += input_data[in_idx];
563                                            }
564                                        }
565                                        count += 1;
566                                    }
567                                }
568                            }
569                        }
570
571                        if count > 0 {
572                            output_data[out_idx] = sum / count as f32;
573                        }
574                    }
575                }
576            }
577        }
578    }
579
580    Tensor::from_data(
581        output_data,
582        vec![batch_size, channels, out_depth, out_height, out_width],
583        input.device(),
584    )
585}