Skip to main content

svod_tensor/nn/
pool.rs

//! Sliding-window pooling and its inverses: pool, avg_pool2d, max_pool2d, max_pool2d_with_indices, max_unpool2d, col2im.
2
3use bon::bon;
4use svod_dtype::DType;
5use svod_ir::{ConstValue, SInt, UOp};
6
7use crate::Tensor;
8use crate::error::DivisibilitySnafu;
9use crate::reduce::AxisSpec;
10
11use super::pad::apply_ceil_mode;
12
13type Result<T> = crate::Result<T>;
14
15impl Tensor {
16    /// Sliding window extraction via shape manipulation (Tinygrad's `_pool`).
17    ///
18    /// Input: `(..., *spatial)` &rarr; Output: `(..., *out_spatial, *kernel)`.
19    ///
20    /// This is a low-level building block for pooling and convolution. It extracts
21    /// all sliding windows of the given kernel size, stride, and dilation from the
22    /// spatial dimensions, appending the kernel dimensions at the end.
23    pub fn pool(&self, kernel: &[usize], stride: &[usize], dilation: &[usize]) -> Result<Tensor> {
24        let shape = self.shape()?;
25        let ndim = shape.len();
26        let n_spatial = kernel.len();
27        let n_batch = ndim - n_spatial;
28
29        if ndim < n_spatial {
30            return Err(crate::error::Error::IrConstruction {
31                details: format!("can't pool {ndim}D with {n_spatial}D kernel"),
32            });
33        }
34        if kernel.len() != stride.len() {
35            return Err(crate::error::Error::IrConstruction {
36                details: format!("kernel/stride length mismatch: {} vs {}", kernel.len(), stride.len()),
37            });
38        }
39        if kernel.len() != dilation.len() {
40            return Err(crate::error::Error::IrConstruction {
41                details: format!("kernel/dilation length mismatch: {} vs {}", kernel.len(), dilation.len()),
42            });
43        }
44
45        // Spatial dims as SInt — works for both concrete and symbolic.
46        let i_: Vec<SInt> = (0..n_spatial).map(|j| shape[n_batch + j].clone()).collect();
47
48        // Validate: kernel must fit in input (concrete dims only — symbolic skips check).
49        for j in 0..n_spatial {
50            if let Some(i) = i_[j].as_const()
51                && dilation[j] * (kernel[j] - 1) >= i
52            {
53                return Err(crate::error::Error::IrConstruction {
54                    details: format!(
55                        "kernel size {} (dilated {}) > input size {}",
56                        kernel[j],
57                        dilation[j] * (kernel[j] - 1) + 1,
58                        i
59                    ),
60                });
61            }
62        }
63
64        // Pool formulas — SInt arithmetic: concrete folds inline, symbolic creates UOp graph.
65        // o_[j] = ceildiv(i_[j] - dilation[j] * (kernel[j] - 1), stride[j])
66        let o_: Vec<SInt> =
67            (0..n_spatial).map(|j| (&i_[j] - dilation[j] * (kernel[j] - 1)).ceildiv(&SInt::from(stride[j]))).collect();
68
69        // f_[j] = max(1, ceildiv(o_[j] * stride[j] - dilation[j], i_[j]))
70        let f_: Vec<SInt> = (0..n_spatial)
71            .map(|j| SInt::from(1usize).smax(&(&o_[j] * stride[j] - dilation[j]).ceildiv(&i_[j])))
72            .collect();
73
74        // Batch dims: None in shrink (identity), SInt in reshape.
75        let noop: Vec<Option<(SInt, SInt)>> = vec![None; n_batch];
76        let batch_sint: Vec<SInt> = shape.iter().take(n_batch).cloned().collect();
77
78        // Step 1: repeat
79        // repeat_count = ceildiv(k * (i*f + d), i)
80        let mut repeats: Vec<SInt> = vec![SInt::from(1usize); n_batch];
81        for j in 0..n_spatial {
82            repeats.push((kernel[j] * (&i_[j] * &f_[j] + dilation[j])).ceildiv(&i_[j]));
83        }
84        let mut x = self.repeat(&repeats)?;
85
86        // Step 2: shrink to exact needed size
87        let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
88        for j in 0..n_spatial {
89            shrink.push(Some((SInt::from(0usize), kernel[j] * (&i_[j] * &f_[j] + dilation[j]))));
90        }
91        x = x.try_shrink(shrink)?;
92
93        // Step 3: reshape to interleave kernel and spatial dims
94        let mut reshape_dims: Vec<SInt> = batch_sint.clone();
95        for j in 0..n_spatial {
96            reshape_dims.push(kernel[j].into());
97            reshape_dims.push(&i_[j] * &f_[j] + dilation[j]);
98        }
99        x = x.try_reshape(reshape_dims)?;
100
101        // Step 4: shrink for stride
102        let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
103        for j in 0..n_spatial {
104            shrink.push(Some((SInt::from(0usize), SInt::from(kernel[j]))));
105            shrink.push(Some((SInt::from(0usize), &o_[j] * stride[j])));
106        }
107        x = x.try_shrink(shrink)?;
108
109        // Step 5: reshape to separate stride: K_j, o_j, S_j
110        let mut reshape_dims: Vec<SInt> = batch_sint.clone();
111        for j in 0..n_spatial {
112            reshape_dims.push(kernel[j].into());
113            reshape_dims.push(o_[j].clone());
114            reshape_dims.push(stride[j].into());
115        }
116        x = x.try_reshape(reshape_dims)?;
117
118        // Step 6: shrink stride dim to 1
119        let mut shrink: Vec<Option<(SInt, SInt)>> = noop.clone();
120        for j in 0..n_spatial {
121            shrink.push(Some((SInt::from(0usize), SInt::from(kernel[j]))));
122            shrink.push(Some((SInt::from(0usize), o_[j].clone())));
123            shrink.push(Some((SInt::from(0usize), SInt::from(1usize))));
124        }
125        x = x.try_shrink(shrink)?;
126
127        // Step 7: reshape to collapse stride dim
128        let mut reshape_dims: Vec<SInt> = batch_sint.clone();
129        for j in 0..n_spatial {
130            reshape_dims.push(kernel[j].into());
131            reshape_dims.push(o_[j].clone());
132        }
133        x = x.try_reshape(reshape_dims)?;
134
135        // Step 8: permute to move kernel dims to end
136        let mut perm: Vec<isize> = (0..n_batch as isize).collect();
137        for j in 0..n_spatial {
138            perm.push(n_batch as isize + j as isize * 2 + 1); // output spatial
139        }
140        for j in 0..n_spatial {
141            perm.push(n_batch as isize + j as isize * 2); // kernel
142        }
143        x = x.try_permute(&perm)?;
144
145        Ok(x)
146    }
147}
148
149#[bon]
150impl Tensor {
151    /// Average pooling over spatial dimensions.
152    ///
153    /// Computes the mean of each sliding window. Supports padding, dilation,
154    /// `count_include_pad` (whether padded zeros count in the denominator),
155    /// and `ceil_mode` (round output size up instead of down).
156    ///
157    /// Stride defaults to `kernel_size` when not specified.
158    ///
159    /// # Examples
160    ///
161    /// Basic 2x2 average pooling:
162    ///
163    /// ```
164    /// # use svod_tensor::Tensor;
165    /// # use ndarray::Array4;
166    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
167    /// let mut y = x.avg_pool2d().kernel_size(&[2, 2]).call().unwrap();
168    /// y.realize().unwrap();
169    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
170    /// assert_eq!(shape, vec![1, 1, 2, 2]);
171    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
172    /// ```
173    ///
174    /// With explicit stride:
175    ///
176    /// ```
177    /// # use svod_tensor::Tensor;
178    /// # use ndarray::Array4;
179    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
180    /// let y = x.avg_pool2d().kernel_size(&[2, 2]).stride(&[1, 1]).call().unwrap();
181    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
182    /// assert_eq!(shape, vec![1, 1, 3, 3]);
183    /// ```
184    ///
185    /// With padding and `count_include_pad` disabled:
186    ///
187    /// ```
188    /// # use svod_tensor::Tensor;
189    /// # use ndarray::Array4;
190    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
191    /// let mut y = x.avg_pool2d()
192    ///     .kernel_size(&[2, 2])
193    ///     .stride(&[1, 1])
194    ///     .padding(&[(1, 1), (1, 1)])
195    ///     .count_include_pad(false)
196    ///     .call()
197    ///     .unwrap();
198    /// y.realize().unwrap();
199    /// // With count_include_pad=false, only non-padded elements count in the average
200    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 9]);
201    /// ```
202    #[builder]
203    pub fn avg_pool2d(
204        &self,
205        kernel_size: &[usize],
206        stride: Option<&[usize]>,
207        dilation: Option<&[usize]>,
208        padding: Option<&[(isize, isize)]>,
209        #[builder(default = true)] count_include_pad: bool,
210        #[builder(default = false)] ceil_mode: bool,
211    ) -> Result<Tensor> {
212        let n_spatial = kernel_size.len();
213        let default_dilation: Vec<usize> = vec![1; n_spatial];
214        let stride = stride.unwrap_or(kernel_size);
215        let dilation = dilation.unwrap_or(&default_dilation);
216        let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
217        let padding = padding.unwrap_or(&no_pad);
218
219        let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
220        let axes = AxisSpec::Multiple(reduce_axes);
221
222        let shape = self.shape()?;
223        let n_batch = shape.len() - n_spatial;
224        let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
225
226        let reg_pads = padding.to_vec();
227        let ceil_pads = if ceil_mode {
228            apply_ceil_mode(&reg_pads, &input_spatial, kernel_size, stride, dilation)
229        } else {
230            reg_pads.clone()
231        };
232
233        let pad_and_pool = |x: &Tensor, pads: &[(isize, isize)]| -> Result<Tensor> {
234            let mut out = x.clone();
235            if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
236                let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
237                full_pad.extend_from_slice(pads);
238                out = out.try_pad(&full_pad)?;
239            }
240            out.pool(kernel_size, stride, dilation)
241        };
242
243        if !count_include_pad {
244            // Path 1: sum(pool(x, pads)) / sum(pool(ones, pads))
245            let pads = if ceil_mode { &ceil_pads } else { &reg_pads };
246            let pooled = pad_and_pool(self, pads)?;
247            let sum_x = pooled.sum_with().axes(axes.clone()).keepdim(false).call()?;
248            // Use input dtype for ones tensor (not hardcoded Float32)
249            let dtype = self.uop().dtype();
250            let ones = Tensor::new(UOp::const_(dtype, ConstValue::Float(1.0)));
251            let ones = ones.broadcast_to(&self.shape()?)?;
252            let pooled_ones = pad_and_pool(&ones, pads)?;
253            let sum_ones = pooled_ones.sum_with().axes(axes).keepdim(false).call()?;
254            return sum_x.try_div(&sum_ones);
255        }
256
257        if !ceil_mode {
258            // Path 2: count_include_pad=true, ceil_mode=false → simple mean
259            let pooled = pad_and_pool(self, &reg_pads)?;
260            return pooled.mean(axes);
261        }
262
263        // Path 3: count_include_pad=true, ceil_mode=true
264        // Regular padding counts in the average, but ceil-extra padding does NOT.
265        // Tinygrad: pool(x, ceil_pads).sum / pool(pad(x, reg_pads).ones_like(), ceil-reg).sum
266        let pooled = pad_and_pool(self, &ceil_pads)?;
267        let sum_x = pooled.sum_with().axes(axes.clone()).keepdim(false).call()?;
268
269        // ones_like of the regularly-padded input (all positions are 1, including reg pads),
270        // then pool with only the extra ceil pads (which add zeros that don't count).
271        let mut padded_self = self.clone();
272        if reg_pads.iter().any(|&(b, e)| b != 0 || e != 0) {
273            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
274            full_pad.extend_from_slice(&reg_pads);
275            padded_self = padded_self.try_pad(&full_pad)?;
276        }
277        let ones_reg = padded_self.one()?;
278        let extra_pads: Vec<(isize, isize)> =
279            ceil_pads.iter().zip(reg_pads.iter()).map(|(c, r)| (c.0 - r.0, c.1 - r.1)).collect();
280        let pooled_ones = pad_and_pool(&ones_reg, &extra_pads)?;
281        let sum_ones = pooled_ones.sum_with().axes(axes).keepdim(false).call()?;
282        sum_x.try_div(&sum_ones)
283    }
284
285    /// Max pooling over spatial dimensions.
286    ///
287    /// Returns the maximum value in each sliding window. Padded positions are
288    /// filled with `-inf` (float) or `i64::MIN` (integer) so they never win.
289    ///
290    /// Stride defaults to `kernel_size` when not specified.
291    ///
292    /// # Examples
293    ///
294    /// Basic 2x2 max pooling:
295    ///
296    /// ```
297    /// # use svod_tensor::Tensor;
298    /// # use ndarray::Array4;
299    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
300    /// let mut y = x.max_pool2d().kernel_size(&[2, 2]).call().unwrap();
301    /// y.realize().unwrap();
302    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
303    /// assert_eq!(shape, vec![1, 1, 2, 2]);
304    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
305    /// ```
306    ///
307    /// With stride and padding:
308    ///
309    /// ```
310    /// # use svod_tensor::Tensor;
311    /// # use ndarray::Array4;
312    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
313    /// let mut y = x.max_pool2d()
314    ///     .kernel_size(&[3, 3])
315    ///     .stride(&[1, 1])
316    ///     .padding(&[(1, 1), (1, 1)])
317    ///     .call()
318    ///     .unwrap();
319    /// y.realize().unwrap();
320    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
321    /// assert_eq!(shape, vec![1, 1, 4, 4]);
322    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 16]);
323    /// ```
324    #[builder]
325    pub fn max_pool2d(
326        &self,
327        kernel_size: &[usize],
328        stride: Option<&[usize]>,
329        dilation: Option<&[usize]>,
330        padding: Option<&[(isize, isize)]>,
331        #[builder(default = false)] ceil_mode: bool,
332    ) -> Result<Tensor> {
333        let n_spatial = kernel_size.len();
334        let default_dilation: Vec<usize> = vec![1; n_spatial];
335        let stride = stride.unwrap_or(kernel_size);
336        let dilation = dilation.unwrap_or(&default_dilation);
337        let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
338        let padding = padding.unwrap_or(&no_pad);
339
340        let pads = if ceil_mode {
341            let shape = self.shape()?;
342            let n_batch = shape.len() - n_spatial;
343            let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
344            apply_ceil_mode(padding, &input_spatial, kernel_size, stride, dilation)
345        } else {
346            padding.to_vec()
347        };
348
349        let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
350        let axes = AxisSpec::Multiple(reduce_axes);
351
352        let mut x = self.clone();
353        if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
354            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); self.ndim()? - n_spatial];
355            full_pad.extend_from_slice(&pads);
356            let fill = if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 };
357            x = x.try_pad_value(&full_pad, fill)?;
358        }
359
360        let pooled = x.pool(kernel_size, stride, dilation)?;
361        pooled.max(axes)
362    }
363
364    /// Max pooling returning both values and flat indices.
365    ///
366    /// Returns `(values, indices)` where indices are flat offsets into the
367    /// input spatial dimensions. Indices can be passed to
368    /// [`max_unpool2d`](Tensor::max_unpool2d) to invert the operation.
369    ///
370    /// Uses a reverse-arange trick (from Tinygrad) to compute first-occurrence
371    /// indices without explicit argmax.
372    ///
373    /// # Examples
374    ///
375    /// ```
376    /// # use svod_tensor::Tensor;
377    /// # use ndarray::Array4;
378    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
379    /// let (mut values, indices) = x.max_pool2d_with_indices()
380    ///     .kernel_size(&[2, 2])
381    ///     .call()
382    ///     .unwrap();
383    /// let _ = indices;
384    /// values.realize().unwrap();
385    /// let shape: Vec<_> = values.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
386    /// assert_eq!(shape, vec![1, 1, 2, 2]);
387    /// assert_eq!(values.as_vec::<f32>().unwrap(), vec![1.0; 4]);
388    /// ```
389    #[builder]
390    pub fn max_pool2d_with_indices(
391        &self,
392        kernel_size: &[usize],
393        stride: Option<&[usize]>,
394        dilation: Option<&[usize]>,
395        padding: Option<&[(isize, isize)]>,
396        #[builder(default = false)] ceil_mode: bool,
397    ) -> Result<(Tensor, Tensor)> {
398        let n_spatial = kernel_size.len();
399        let default_dilation: Vec<usize> = vec![1; n_spatial];
400        let stride = stride.unwrap_or(kernel_size);
401        let dilation = dilation.unwrap_or(&default_dilation);
402        let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
403        let padding = padding.unwrap_or(&no_pad);
404
405        let shape = self.shape()?;
406        let n_batch = shape.len() - n_spatial;
407
408        let pads = if ceil_mode {
409            let input_spatial: Vec<SInt> = shape[n_batch..].to_vec();
410            apply_ceil_mode(padding, &input_spatial, kernel_size, stride, dilation)
411        } else {
412            padding.to_vec()
413        };
414
415        let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
416        let axes = AxisSpec::Multiple(reduce_axes.clone());
417
418        // Pool the data with dtype-minimum padding
419        let mut x = self.clone();
420        if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
421            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
422            full_pad.extend_from_slice(&pads);
423            let fill = if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 };
424            x = x.try_pad_value(&full_pad, fill)?;
425        }
426        let pooled = x.pool(kernel_size, stride, dilation)?;
427        let values = pooled.max_with().axes(axes.clone()).keepdim(false).call()?;
428
429        // Compute indices using reverse arange trick (Tinygrad approach)
430        let spatial_sz: usize = (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap()).product();
431
432        // Create reverse arange: spatial_sz, spatial_sz-1, ..., 1
433        let idx_range = Tensor::arange(spatial_sz as i64, Some(0), Some(-1))?;
434        // Reshape to match spatial dims
435        let spatial_dims: Vec<isize> =
436            (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap() as isize).collect();
437        let mut idx_shape: Vec<isize> = vec![1; n_batch];
438        idx_shape.extend_from_slice(&spatial_dims);
439        let idx = idx_range.try_reshape(&idx_shape)?;
440
441        // Pad and pool the index tensor identically
442        let mut idx_padded = idx;
443        if pads.iter().any(|&(b, e)| b != 0 || e != 0) {
444            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
445            full_pad.extend_from_slice(&pads);
446            idx_padded = idx_padded.try_pad(&full_pad)?;
447        }
448        let pooled_idx = idx_padded.pool(kernel_size, stride, dilation)?;
449
450        // Create mask: pooled == pooled.max(keepdim=True)
451        let pooled_max = pooled.max_with().axes(axes.clone()).keepdim(true).call()?;
452        let mask = pooled.try_eq(&pooled_max)?;
453
454        // Multiply mask * pooled_indices, take max → first-occurrence (via reverse index)
455        let masked_idx = mask.cast(DType::Int32)?.try_mul(&pooled_idx)?;
456        let max_idx = masked_idx.max_with().axes(axes).keepdim(false).call()?;
457
458        // spatial_sz - max_idx → convert reverse index to forward index
459        let sz_t = Tensor::const_(ConstValue::Int(spatial_sz as i64), DType::Int32);
460        let indices = sz_t.try_sub(&max_idx)?;
461
462        Ok((values, indices))
463    }
464
465    /// Inverse of max pooling: scatter pooled values back to their original positions.
466    ///
467    /// Indices are flat offsets into the *inferred* output spatial shape (computed
468    /// from kernel/stride/padding). When `output_size` exceeds the inferred shape,
469    /// the result is zero-padded to match.
470    ///
471    /// Uses one-hot encoding of indices to scatter values: `one_hot(idx) * vals -> sum`.
472    ///
473    /// # Examples
474    ///
475    /// Round-trip with max_pool2d_with_indices:
476    ///
477    /// ```
478    /// # use svod_tensor::Tensor;
479    /// # use ndarray::Array4;
480    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
481    /// let (values, indices) = x.max_pool2d_with_indices()
482    ///     .kernel_size(&[2, 2])
483    ///     .call()
484    ///     .unwrap();
485    /// let unpooled = values.max_unpool2d()
486    ///     .indices(&indices)
487    ///     .kernel_size(&[2, 2])
488    ///     .call()
489    ///     .unwrap();
490    /// let shape: Vec<_> = unpooled.shape().unwrap().iter()
491    ///     .map(|d| d.as_const().unwrap()).collect();
492    /// assert_eq!(shape, vec![1, 1, 4, 4]);
493    /// ```
494    #[builder]
495    pub fn max_unpool2d(
496        &self,
497        indices: &Tensor,
498        kernel_size: &[usize],
499        stride: Option<&[usize]>,
500        padding: Option<&[(isize, isize)]>,
501        output_size: Option<&[usize]>,
502    ) -> Result<Tensor> {
503        let shape = self.shape()?;
504        let ndim = shape.len();
505        let n_spatial = kernel_size.len();
506        let n_batch = ndim - n_spatial;
507
508        let spatial_shape: Vec<usize> = (0..n_spatial).map(|j| shape[n_batch + j].as_const().unwrap()).collect();
509
510        // Inferred shape from inverse pooling formula: o = (i-1)*s - (pB+pA) + k
511        let stride = stride.unwrap_or(kernel_size);
512        let no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
513        let padding = padding.unwrap_or(&no_pad);
514        let inferred_spatial: Vec<usize> = (0..n_spatial)
515            .map(|j| {
516                let (pa, pb) = padding[j];
517                (spatial_shape[j] - 1) * stride[j] - (pa as usize + pb as usize) + kernel_size[j]
518            })
519            .collect();
520
521        let inferred_numel: usize = inferred_spatial.iter().product();
522        let bs: usize = (0..n_batch).map(|j| shape[j].as_const().unwrap()).product();
523
524        // Flatten: (N, C, *spatial) → (N*C, 1, num_pooled)
525        let num_pooled: usize = spatial_shape.iter().product();
526        let vals_flat = self.try_reshape([bs as isize, 1, num_pooled as isize])?;
527        let idx_flat = indices.try_reshape([bs as isize, 1, num_pooled as isize])?;
528
529        // One-hot: compare indices against arange(inferred_numel)
530        let arange = Tensor::arange(inferred_numel as i64, None, None)?.cast(indices.uop().dtype())?.try_reshape([
531            1,
532            inferred_numel as isize,
533            1,
534        ])?;
535        let one_hot = idx_flat.try_eq(&arange)?;
536
537        // Place values at one-hot positions, zero elsewhere, then sum over pooled dim
538        let zero = Tensor::const_(0.0f64, self.uop().dtype());
539        let placed = vals_flat.where_(&one_hot, &zero)?;
540        let result = placed.sum(-1isize)?;
541
542        // Reshape to (N, C, *inferred_spatial)
543        let batch_dims: Vec<isize> = (0..n_batch).map(|j| shape[j].as_const().unwrap() as isize).collect();
544        let mut inferred_shape: Vec<isize> = batch_dims.clone();
545        inferred_shape.extend(inferred_spatial.iter().map(|&s| s as isize));
546        let result = result.try_reshape(&inferred_shape)?;
547
548        // If output_size is larger, zero-pad to match
549        if let Some(os) = output_size {
550            let out_spatial = &os[os.len() - n_spatial..];
551            if out_spatial != inferred_spatial.as_slice() {
552                let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); n_batch];
553                for j in 0..n_spatial {
554                    pad_spec.push((0, (out_spatial[j] - inferred_spatial[j]) as isize));
555                }
556                return result.try_pad(&pad_spec);
557            }
558        }
559        Ok(result)
560    }
561
562    /// Col2Im: adjoint of im2col. Reconstructs an image from columns, summing overlaps.
563    ///
564    /// Input shape: `[N, C * prod(block_shape), L]` where `L` is the number of sliding positions.
565    /// Output shape: `[N, C, *image_shape]`.
566    ///
567    /// Uses the adjoint of [`pool`](Tensor::pool): for each kernel position, stride-dilate
568    /// the column data, pad to the correct offset, and accumulate. `O(output_size)` memory,
569    /// `O(bl * output_size)` compute -- no large one-hot intermediates.
570    ///
571    /// # Examples
572    ///
573    /// Reconstruct a 4x4 image from 2x2 blocks with no overlap:
574    ///
575    /// ```
576    /// # use svod_tensor::Tensor;
577    /// # use ndarray::Array3;
578    /// // 1 batch, 1 channel, 2x2 block = 4 cols, 4 sliding positions
579    /// let cols = Tensor::from_ndarray(&Array3::from_elem((1, 4, 4), 1.0f32));
580    /// let mut img = cols.col2im()
581    ///     .image_shape(&[4, 4])
582    ///     .block_shape(&[2, 2])
583    ///     .strides(&[2, 2])
584    ///     .call()
585    ///     .unwrap();
586    /// img.realize().unwrap();
587    /// let shape: Vec<_> = img.shape().unwrap().iter()
588    ///     .map(|d| d.as_const().unwrap()).collect();
589    /// assert_eq!(shape, vec![1, 1, 4, 4]);
590    /// // Non-overlapping blocks of ones reconstruct to all ones
591    /// assert_eq!(img.as_vec::<f32>().unwrap(), vec![1.0; 16]);
592    /// ```
593    #[builder]
594    pub fn col2im(
595        &self,
596        image_shape: &[usize],
597        block_shape: &[usize],
598        strides: Option<&[usize]>,
599        pads: Option<&[(isize, isize)]>,
600        dilations: Option<&[usize]>,
601    ) -> Result<Tensor> {
602        let n_spatial = image_shape.len();
603        let no_strides: Vec<usize> = vec![1; n_spatial];
604        let no_pads: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
605        let no_dilations: Vec<usize> = vec![1; n_spatial];
606        let strides = strides.unwrap_or(&no_strides);
607        let pads = pads.unwrap_or(&no_pads);
608        let dilations = dilations.unwrap_or(&no_dilations);
609
610        let shape = self.shape()?;
611        let n = shape[0].as_const().unwrap();
612        let c_times_bl: usize = shape[1].as_const().unwrap();
613        let bl: usize = block_shape.iter().product();
614        snafu::ensure!(
615            c_times_bl.is_multiple_of(bl),
616            DivisibilitySnafu {
617                op: "col2im",
618                lhs_name: "C*block_size",
619                lhs: c_times_bl,
620                rhs_name: "block_size",
621                rhs: bl
622            }
623        );
624        let c = c_times_bl / bl;
625
626        // Padded image shape (reconstruct in padded space, shrink at end)
627        let padded_img: Vec<usize> =
628            (0..n_spatial).map(|i| (image_shape[i] as isize + pads[i].0 + pads[i].1) as usize).collect();
629
630        // Number of sliding positions per spatial dimension
631        let l_spatial: Vec<usize> = (0..n_spatial)
632            .map(|i| {
633                let effective_k = dilations[i] * (block_shape[i] - 1) + 1;
634                (padded_img[i] - effective_k) / strides[i] + 1
635            })
636            .collect();
637
638        // Reshape input: [N, C*bl, L] → [N*C, *block_shape, *L_spatial]
639        let nc = n * c;
640        let mut data_shape: Vec<isize> = vec![nc as isize];
641        data_shape.extend(block_shape.iter().map(|&s| s as isize));
642        data_shape.extend(l_spatial.iter().map(|&s| s as isize));
643        let data = self.try_reshape(&data_shape)?;
644
645        // Initialize output: [N*C, *padded_img] with zeros
646        let mut out_dims: Vec<usize> = vec![nc];
647        out_dims.extend_from_slice(&padded_img);
648        let mut result = Tensor::full(&out_dims, 0.0f64, self.uop().dtype())?;
649
650        // Iterate over all kernel positions in block_shape
651        for be in 0..bl {
652            // Unravel be → (k0, k1, ..., k_{n-1})
653            let mut kpos = vec![0usize; n_spatial];
654            let mut rem = be;
655            for i in (0..n_spatial).rev() {
656                kpos[i] = rem % block_shape[i];
657                rem /= block_shape[i];
658            }
659
660            // Extract slice for this kernel position: [N*C, *L_spatial]
661            // Shrink block dims (dims 1..n_spatial) to singletons, keep L_spatial dims
662            let mut shrink_ranges: Vec<(isize, isize)> = vec![(0, nc as isize)];
663            for &k in kpos.iter().take(n_spatial) {
664                shrink_ranges.push((k as isize, k as isize + 1));
665            }
666            for &l in l_spatial.iter().take(n_spatial) {
667                shrink_ranges.push((0, l as isize));
668            }
669            let slice = data.try_shrink(&shrink_ranges)?;
670            // Squeeze block dims → [N*C, *L_spatial]
671            let mut sq_shape: Vec<isize> = vec![nc as isize];
672            sq_shape.extend(l_spatial.iter().map(|&s| s as isize));
673            let mut slice = slice.try_reshape(&sq_shape)?;
674
675            // For each spatial dim: stride-dilate L_j, then pad to position
676            for j in 0..n_spatial {
677                let dim = 1 + j;
678                let l_j = l_spatial[j];
679
680                // Stride dilation: insert stride-1 zeros between elements
681                if strides[j] > 1 {
682                    let s = strides[j];
683                    let ndim = slice.shape()?.len();
684                    // [... L_j ...] → [... L_j, 1 ...] → pad → [... L_j, S ...] → [... L_j*S ...] → shrink
685                    let mut sh: Vec<isize> = slice.shape()?.iter().map(|d| d.as_const().unwrap() as isize).collect();
686                    sh.insert(dim + 1, 1);
687                    slice = slice.try_reshape(&sh)?;
688
689                    let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim + 1];
690                    pad_spec[dim + 1] = (0, (s - 1) as isize);
691                    slice = slice.try_pad(&pad_spec)?;
692
693                    sh[dim] = (l_j * s) as isize;
694                    sh.remove(dim + 1);
695                    slice = slice.try_reshape(&sh)?;
696
697                    let dilated_l = (l_j - 1) * s + 1;
698                    let mut sr: Vec<(isize, isize)> =
699                        slice.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
700                    sr[dim] = (0, dilated_l as isize);
701                    slice = slice.try_shrink(&sr)?;
702                }
703
704                // Pad for kernel position offset: left = k*d, right = (K-1-k)*d
705                let left = kpos[j] * dilations[j];
706                let right = (block_shape[j] - 1 - kpos[j]) * dilations[j];
707                if left > 0 || right > 0 {
708                    let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); slice.shape()?.len()];
709                    pad_spec[dim] = (left as isize, right as isize);
710                    slice = slice.try_pad(&pad_spec)?;
711                }
712            }
713
714            result = result.try_add(&slice)?;
715        }
716
717        // Shrink to remove padding → [N*C, *image_shape]
718        let mut shrink_ranges: Vec<(isize, isize)> = vec![(0, nc as isize)];
719        for j in 0..n_spatial {
720            shrink_ranges.push((pads[j].0, pads[j].0 + image_shape[j] as isize));
721        }
722        let result = result.try_shrink(&shrink_ranges)?;
723
724        // Reshape to [N, C, *image_shape]
725        let mut final_shape: Vec<isize> = vec![n as isize, c as isize];
726        final_shape.extend(image_shape.iter().map(|&s| s as isize));
727        result.try_reshape(&final_shape)
728    }
729}