1use bon::bon;
4use svod_dtype::DType;
5use svod_ir::{ConstValue, SInt, UOp};
6
7use crate::Tensor;
8use crate::error::DivisibilitySnafu;
9use crate::reduce::AxisSpec;
10
11use super::pad::apply_ceil_mode;
12
13type Result<T> = crate::Result<T>;
14
15impl Tensor {
16 pub fn pool(&self, kernel: &[usize], stride: &[usize], dilation: &[usize]) -> Result<Tensor> {
24 let shape = self.shape()?;
25 let ndim = shape.len();
26 let n_spatial = kernel.len();
27 let n_batch = ndim - n_spatial;
28
29 if ndim < n_spatial {
30 return Err(crate::error::Error::IrConstruction {
31 details: format!("can't pool {ndim}D with {n_spatial}D kernel"),
32 });
33 }
34 if kernel.len() != stride.len() {
35 return Err(crate::error::Error::IrConstruction {
36 details: format!("kernel/stride length mismatch: {} vs {}", kernel.len(), stride.len()),
37 });
38 }
39 if kernel.len() != dilation.len() {
40 return Err(crate::error::Error::IrConstruction {
41 details: format!("kernel/dilation length mismatch: {} vs {}", kernel.len(), dilation.len()),
42 });
43 }
44
45 let i_: Vec<SInt> = (0..n_spatial).map(|j| shape[n_batch + j].clone()).collect();
47
48 for j in 0..n_spatial {
50 if let Some(i) = i_[j].as_const()
51 && dilation[j] * (kernel[j] - 1) >= i
52 {
53 return Err(crate::error::Error::IrConstruction {
54 details: format!(
55 "kernel size {} (dilated {}) > input size {}",
56 kernel[j],
57 dilation[j] * (kernel[j] - 1) + 1,
58 i
59 ),
60 });
61 }
62 }
63
64 let o_: Vec<SInt> =
67 (0..n_spatial).map(|j| (&i_[j] - dilation[j] * (kernel[j] - 1)).ceildiv(&SInt::from(stride[j]))).collect();
68
69 let f_: Vec<SInt> = (0..n_spatial)
71 .map(|j| SInt::from(1usize).smax(&(&o_[j] * stride[j] - dilation[j]).ceildiv(&i_[j])))
72 .collect();
73
74 let noop: Vec<Option<(SInt, SInt)>> = vec![None; n_batch];
76 let batch_sint: Vec<SInt> = shape.iter().take(n_batch).cloned().collect();
77
78 let mut repeats: Vec<SInt> = vec![SInt::from(1usize); n_batch];
81 for j in 0..n_spatial {
82 repeats.push((kernel[j] * (&i_[j] * &f_[j] + dilation[j])).ceildiv(&i_[j]));
83 }
84 let mut x = self.repeat(&repeats)?;
85
86 let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
88 for j in 0..n_spatial {
89 shrink.push(Some((SInt::from(0usize), kernel[j] * (&i_[j] * &f_[j] + dilation[j]))));
90 }
91 x = x.try_shrink(shrink)?;
92
93 let mut reshape_dims: Vec<SInt> = batch_sint.clone();
95 for j in 0..n_spatial {
96 reshape_dims.push(kernel[j].into());
97 reshape_dims.push(&i_[j] * &f_[j] + dilation[j]);
98 }
99 x = x.try_reshape(reshape_dims)?;
100
101 let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
103 for j in 0..n_spatial {
104 shrink.push(Some((SInt::from(0usize), SInt::from(kernel[j]))));
105 shrink.push(Some((SInt::from(0usize), &o_[j] * stride[j])));
106 }
107 x = x.try_shrink(shrink)?;
108
109 let mut reshape_dims: Vec<SInt> = batch_sint.clone();
111 for j in 0..n_spatial {
112 reshape_dims.push(kernel[j].into());
113 reshape_dims.push(o_[j].clone());
114 reshape_dims.push(stride[j].into());
115 }
116 x = x.try_reshape(reshape_dims)?;
117
118 let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
120 for j in 0..n_spatial {
121 shrink.push(Some((SInt::from(0usize), SInt::from(kernel[j]))));
122 shrink.push(Some((SInt::from(0usize), o_[j].clone())));
123 shrink.push(Some((SInt::from(0usize), SInt::from(1usize))));
124 }
125 x = x.try_shrink(shrink)?;
126
127 let mut reshape_dims: Vec<SInt> = batch_sint.clone();
129 for j in 0..n_spatial {
130 reshape_dims.push(kernel[j].into());
131 reshape_dims.push(o_[j].clone());
132 }
133 x = x.try_reshape(reshape_dims)?;
134
135 let mut perm: Vec<isize> = (0..n_batch as isize).collect();
137 for j in 0..n_spatial {
138 perm.push(n_batch as isize + j as isize * 2 + 1); }
140 for j in 0..n_spatial {
141 perm.push(n_batch as isize + j as isize * 2); }
143 x = x.try_permute(&perm)?;
144
145 Ok(x)
146 }
147}
148
149#[bon]
150impl Tensor {
151 #[builder]
203 pub fn avg_pool2d(
204 &self,
205 kernel_size: &[usize],
206 stride: Option<&[usize]>,
207 dilation: Option<&[usize]>,
208 padding: Option<&[(isize, isize)]>,
209 #[builder(default = true)] count_include_pad: bool,
210 #[builder(default = false)] ceil_mode: bool,
211 ) -> Result<Tensor> {
212 let n_spatial = kernel_size.len();
213 let default_dilation: Vec<usize> = vec![1; n_spatial];
214 let stride = stride.unwrap_or(kernel_size);
215 let dilation = dilation.unwrap_or(&default_dilation);
216 let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
217 let padding = padding.unwrap_or(&no_pad);
218
219 let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
220 let axes = AxisSpec::Multiple(reduce_axes);
221
222 let shape = self.shape()?;
223 let n_batch = shape.len() - n_spatial;
224 let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
225
226 let reg_pads = padding.to_vec();
227 let ceil_pads = if ceil_mode {
228 apply_ceil_mode(®_pads, &input_spatial, kernel_size, stride, dilation)
229 } else {
230 reg_pads.clone()
231 };
232
233 let pad_and_pool = |x: &Tensor, pads: &[(isize, isize)]| -> Result<Tensor> {
234 let mut out = x.clone();
235 if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
236 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
237 full_pad.extend_from_slice(pads);
238 out = out.try_pad(&full_pad)?;
239 }
240 out.pool(kernel_size, stride, dilation)
241 };
242
243 if !count_include_pad {
244 let pads = if ceil_mode { &ceil_pads } else { ®_pads };
246 let pooled = pad_and_pool(self, pads)?;
247 let sum_x = pooled.sum_with().axes(axes.clone()).keepdim(false).call()?;
248 let dtype = self.uop().dtype();
250 let ones = Tensor::new(UOp::const_(dtype, ConstValue::Float(1.0)));
251 let ones = ones.broadcast_to(&self.shape()?)?;
252 let pooled_ones = pad_and_pool(&ones, pads)?;
253 let sum_ones = pooled_ones.sum_with().axes(axes).keepdim(false).call()?;
254 return sum_x.try_div(&sum_ones);
255 }
256
257 if !ceil_mode {
258 let pooled = pad_and_pool(self, ®_pads)?;
260 return pooled.mean(axes);
261 }
262
263 let pooled = pad_and_pool(self, &ceil_pads)?;
267 let sum_x = pooled.sum_with().axes(axes.clone()).keepdim(false).call()?;
268
269 let mut padded_self = self.clone();
272 if reg_pads.iter().any(|&(b, e)| b != 0 || e != 0) {
273 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
274 full_pad.extend_from_slice(®_pads);
275 padded_self = padded_self.try_pad(&full_pad)?;
276 }
277 let ones_reg = padded_self.one()?;
278 let extra_pads: Vec<(isize, isize)> =
279 ceil_pads.iter().zip(reg_pads.iter()).map(|(c, r)| (c.0 - r.0, c.1 - r.1)).collect();
280 let pooled_ones = pad_and_pool(&ones_reg, &extra_pads)?;
281 let sum_ones = pooled_ones.sum_with().axes(axes).keepdim(false).call()?;
282 sum_x.try_div(&sum_ones)
283 }
284
285 #[builder]
325 pub fn max_pool2d(
326 &self,
327 kernel_size: &[usize],
328 stride: Option<&[usize]>,
329 dilation: Option<&[usize]>,
330 padding: Option<&[(isize, isize)]>,
331 #[builder(default = false)] ceil_mode: bool,
332 ) -> Result<Tensor> {
333 let n_spatial = kernel_size.len();
334 let default_dilation: Vec<usize> = vec![1; n_spatial];
335 let stride = stride.unwrap_or(kernel_size);
336 let dilation = dilation.unwrap_or(&default_dilation);
337 let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
338 let padding = padding.unwrap_or(&no_pad);
339
340 let pads = if ceil_mode {
341 let shape = self.shape()?;
342 let n_batch = shape.len() - n_spatial;
343 let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
344 apply_ceil_mode(padding, &input_spatial, kernel_size, stride, dilation)
345 } else {
346 padding.to_vec()
347 };
348
349 let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
350 let axes = AxisSpec::Multiple(reduce_axes);
351
352 let mut x = self.clone();
353 if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
354 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); self.ndim()? - n_spatial];
355 full_pad.extend_from_slice(&pads);
356 let fill = if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 };
357 x = x.try_pad_value(&full_pad, fill)?;
358 }
359
360 let pooled = x.pool(kernel_size, stride, dilation)?;
361 pooled.max(axes)
362 }
363
364 #[builder]
390 pub fn max_pool2d_with_indices(
391 &self,
392 kernel_size: &[usize],
393 stride: Option<&[usize]>,
394 dilation: Option<&[usize]>,
395 padding: Option<&[(isize, isize)]>,
396 #[builder(default = false)] ceil_mode: bool,
397 ) -> Result<(Tensor, Tensor)> {
398 let n_spatial = kernel_size.len();
399 let default_dilation: Vec<usize> = vec![1; n_spatial];
400 let stride = stride.unwrap_or(kernel_size);
401 let dilation = dilation.unwrap_or(&default_dilation);
402 let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
403 let padding = padding.unwrap_or(&no_pad);
404
405 let shape = self.shape()?;
406 let n_batch = shape.len() - n_spatial;
407
408 let pads = if ceil_mode {
409 let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
410 apply_ceil_mode(padding, &input_spatial, kernel_size, stride, dilation)
411 } else {
412 padding.to_vec()
413 };
414
415 let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
416 let axes = AxisSpec::Multiple(reduce_axes.clone());
417
418 let mut x = self.clone();
420 if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
421 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
422 full_pad.extend_from_slice(&pads);
423 let fill = if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 };
424 x = x.try_pad_value(&full_pad, fill)?;
425 }
426 let pooled = x.pool(kernel_size, stride, dilation)?;
427 let values = pooled.max_with().axes(axes.clone()).keepdim(false).call()?;
428
429 let spatial_sz: usize = (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap()).product();
431
432 let idx_range = Tensor::arange(spatial_sz as i64, Some(0), Some(-1))?;
434 let spatial_dims: Vec<isize> =
436 (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap() as isize).collect();
437 let mut idx_shape: Vec<isize> = vec![1; n_batch];
438 idx_shape.extend_from_slice(&spatial_dims);
439 let idx = idx_range.try_reshape(&idx_shape)?;
440
441 let mut idx_padded = idx;
443 if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
444 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
445 full_pad.extend_from_slice(&pads);
446 idx_padded = idx_padded.try_pad(&full_pad)?;
447 }
448 let pooled_idx = idx_padded.pool(kernel_size, stride, dilation)?;
449
450 let pooled_max = pooled.max_with().axes(axes.clone()).keepdim(true).call()?;
452 let mask = pooled.try_eq(&pooled_max)?;
453
454 let masked_idx = mask.cast(DType::Int32)?.try_mul(&pooled_idx)?;
456 let max_idx = masked_idx.max_with().axes(axes).keepdim(false).call()?;
457
458 let sz_t = Tensor::const_(ConstValue::Int(spatial_sz as i64), DType::Int32);
460 let indices = sz_t.try_sub(&max_idx)?;
461
462 Ok((values, indices))
463 }
464
465 #[builder]
495 pub fn max_unpool2d(
496 &self,
497 indices: &Tensor,
498 kernel_size: &[usize],
499 stride: Option<&[usize]>,
500 padding: Option<&[(isize, isize)]>,
501 output_size: Option<&[usize]>,
502 ) -> Result<Tensor> {
503 let shape = self.shape()?;
504 let ndim = shape.len();
505 let n_spatial = kernel_size.len();
506 let n_batch = ndim - n_spatial;
507
508 let spatial_shape: Vec<usize> = (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap()).collect();
509
510 let stride = stride.unwrap_or(kernel_size);
512 let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
513 let padding = padding.unwrap_or(&no_pad);
514 let inferred_spatial: Vec<usize> = (0..n_spatial)
515 .map(|j| {
516 let (pa, pb) = padding[j];
517 (spatial_shape[j] - 1) * stride[j] - (pa as usize + pb as usize) + kernel_size[j]
518 })
519 .collect();
520
521 let inferred_numel: usize = inferred_spatial.iter().product();
522 let bs: usize = (0..n_batch).map(|j| shape[j].as_const().unwrap()).product();
523
524 let num_pooled: usize = spatial_shape.iter().product();
526 let vals_flat = self.try_reshape([bs as isize, 1, num_pooled as isize])?;
527 let idx_flat = indices.try_reshape([bs as isize, 1, num_pooled as isize])?;
528
529 let arange = Tensor::arange(inferred_numel as i64, None, None)?.cast(indices.uop().dtype())?.try_reshape([
531 1,
532 inferred_numel as isize,
533 1,
534 ])?;
535 let one_hot = idx_flat.try_eq(&arange)?;
536
537 let zero = Tensor::const_(0.0f64, self.uop().dtype());
539 let placed = vals_flat.where_(&one_hot, &zero)?;
540 let result = placed.sum(-1isize)?;
541
542 let batch_dims: Vec<isize> = (0..n_batch).map(|j| shape[j].as_const().unwrap() as isize).collect();
544 let mut inferred_shape: Vec<isize> = batch_dims.clone();
545 inferred_shape.extend(inferred_spatial.iter().map(|&s| s as isize));
546 let result = result.try_reshape(&inferred_shape)?;
547
548 if let Some(os) = output_size {
550 let out_spatial = &os[os.len() - n_spatial..];
551 if out_spatial != inferred_spatial.as_slice() {
552 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); n_batch];
553 for j in 0..n_spatial {
554 pad_spec.push((0, (out_spatial[j] - inferred_spatial[j]) as isize));
555 }
556 return result.try_pad(&pad_spec);
557 }
558 }
559 Ok(result)
560 }
561
562 #[builder]
594 pub fn col2im(
595 &self,
596 image_shape: &[usize],
597 block_shape: &[usize],
598 strides: Option<&[usize]>,
599 pads: Option<&[(isize, isize)]>,
600 dilations: Option<&[usize]>,
601 ) -> Result<Tensor> {
602 let n_spatial = image_shape.len();
603 let no_strides: Vec<usize> = vec![1; n_spatial];
604 let no_pads: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
605 let no_dilations: Vec<usize> = vec![1; n_spatial];
606 let strides = strides.unwrap_or(&no_strides);
607 let pads = pads.unwrap_or(&no_pads);
608 let dilations = dilations.unwrap_or(&no_dilations);
609
610 let shape = self.shape()?;
611 let n = shape[0].as_const().unwrap();
612 let c_times_bl: usize = shape[1].as_const().unwrap();
613 let bl: usize = block_shape.iter().product();
614 snafu::ensure!(
615 c_times_bl.is_multiple_of(bl),
616 DivisibilitySnafu {
617 op: "col2im",
618 lhs_name: "C*block_size",
619 lhs: c_times_bl,
620 rhs_name: "block_size",
621 rhs: bl
622 }
623 );
624 let c = c_times_bl / bl;
625
626 let padded_img: Vec<usize> =
628 (0..n_spatial).map(|i| (image_shape[i] as isize + pads[i].0 + pads[i].1) as usize).collect();
629
630 let l_spatial: Vec<usize> = (0..n_spatial)
632 .map(|i| {
633 let effective_k = dilations[i] * (block_shape[i] - 1) + 1;
634 (padded_img[i] - effective_k) / strides[i] + 1
635 })
636 .collect();
637
638 let nc = n * c;
640 let mut data_shape: Vec<isize> = vec![nc as isize];
641 data_shape.extend(block_shape.iter().map(|&s| s as isize));
642 data_shape.extend(l_spatial.iter().map(|&s| s as isize));
643 let data = self.try_reshape(&data_shape)?;
644
645 let mut out_dims: Vec<usize> = vec![nc];
647 out_dims.extend_from_slice(&padded_img);
648 let mut result = Tensor::full(&out_dims, 0.0f64, self.uop().dtype())?;
649
650 for be in 0..bl {
652 let mut kpos = vec![0usize; n_spatial];
654 let mut rem = be;
655 for i in (0..n_spatial).rev() {
656 kpos[i] = rem % block_shape[i];
657 rem /= block_shape[i];
658 }
659
660 let mut shrink_ranges: Vec<(isize, isize)> = vec![(0, nc as isize)];
663 for &k in kpos.iter().take(n_spatial) {
664 shrink_ranges.push((k as isize, k as isize + 1));
665 }
666 for &l in l_spatial.iter().take(n_spatial) {
667 shrink_ranges.push((0, l as isize));
668 }
669 let slice = data.try_shrink(&shrink_ranges)?;
670 let mut sq_shape: Vec<isize> = vec![nc as isize];
672 sq_shape.extend(l_spatial.iter().map(|&s| s as isize));
673 let mut slice = slice.try_reshape(&sq_shape)?;
674
675 for j in 0..n_spatial {
677 let dim = 1 + j;
678 let l_j = l_spatial[j];
679
680 if strides[j] > 1 {
682 let s = strides[j];
683 let ndim = slice.shape()?.len();
684 let mut sh: Vec<isize> = slice.shape()?.iter().map(|d| d.as_const().unwrap() as isize).collect();
686 sh.insert(dim + 1, 1);
687 slice = slice.try_reshape(&sh)?;
688
689 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim + 1];
690 pad_spec[dim + 1] = (0, (s - 1) as isize);
691 slice = slice.try_pad(&pad_spec)?;
692
693 sh[dim] = (l_j * s) as isize;
694 sh.remove(dim + 1);
695 slice = slice.try_reshape(&sh)?;
696
697 let dilated_l = (l_j - 1) * s + 1;
698 let mut sr: Vec<(isize, isize)> =
699 slice.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
700 sr[dim] = (0, dilated_l as isize);
701 slice = slice.try_shrink(&sr)?;
702 }
703
704 let left = kpos[j] * dilations[j];
706 let right = (block_shape[j] - 1 - kpos[j]) * dilations[j];
707 if left > 0 || right > 0 {
708 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); slice.shape()?.len()];
709 pad_spec[dim] = (left as isize, right as isize);
710 slice = slice.try_pad(&pad_spec)?;
711 }
712 }
713
714 result = result.try_add(&slice)?;
715 }
716
717 let mut shrink_ranges: Vec<(isize, isize)> = vec![(0, nc as isize)];
719 for j in 0..n_spatial {
720 shrink_ranges.push((pads[j].0, pads[j].0 + image_shape[j] as isize));
721 }
722 let result = result.try_shrink(&shrink_ranges)?;
723
724 let mut final_shape: Vec<isize> = vec![n as isize, c as isize];
726 final_shape.extend(image_shape.iter().map(|&s| s as isize));
727 result.try_reshape(&final_shape)
728 }
729}