Skip to main content

svod_tensor/nn/
conv.rs

1//! Convolution operations: conv2d, conv_transpose2d.
2
3use bon::bon;
4
5use svod_ir::SInt;
6
7use crate::Tensor;
8use crate::reduce::AxisSpec;
9
/// Shorthand for `crate::Result<T>`, the result type returned by the fallible
/// convolution builders in this module.
type Result<T> = crate::Result<T>;
11
12#[bon]
13impl Tensor {
14    /// N-d convolution. Input `(N, Cin, *spatial)`, Weight `(Cout, Cin/groups, *kernel)`.
15    ///
16    /// Computes cross-correlation (conv without kernel flip) by extracting sliding
17    /// windows via [`pool`](Tensor::pool), then contracting against the weight tensor.
18    /// Supports grouped convolution, strided/dilated kernels, and asymmetric padding.
19    ///
20    /// # Examples
21    ///
22    /// Basic 2D convolution with uniform data:
23    ///
24    /// ```
25    /// # use svod_tensor::Tensor;
26    /// # use ndarray::Array4;
27    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
28    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
29    /// let mut y = x.conv2d().weight(&w).call().unwrap();
30    /// y.realize().unwrap();
31    /// // 3x3 kernel of ones on input of ones => each output element is 9.0
32    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 9]);
33    /// ```
34    ///
35    /// With stride:
36    ///
37    /// ```
38    /// # use svod_tensor::Tensor;
39    /// # use ndarray::Array4;
40    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
41    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
42    /// let mut y = x.conv2d().weight(&w).stride(&[2, 2]).call().unwrap();
43    /// y.realize().unwrap();
44    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
45    /// assert_eq!(shape, vec![1, 1, 2, 2]);
46    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 4]);
47    /// ```
48    ///
49    /// With padding:
50    ///
51    /// ```
52    /// # use svod_tensor::Tensor;
53    /// # use ndarray::Array4;
54    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
55    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
56    /// // padding=1 on each side: output matches input spatial dims
57    /// let mut y = x.conv2d().weight(&w).padding(&[(1, 1), (1, 1)]).call().unwrap();
58    /// y.realize().unwrap();
59    /// let vals = y.as_vec::<f32>().unwrap();
60    /// assert_eq!(vals.len(), 9); // 3x3 output
61    /// // Center element sees full 3x3 window of ones = 9.0
62    /// assert_eq!(vals[4], 9.0);
63    /// // Corner element sees 2x2 window = 4.0
64    /// assert_eq!(vals[0], 4.0);
65    /// ```
66    ///
67    /// With bias:
68    ///
69    /// ```
70    /// # use svod_tensor::Tensor;
71    /// # use ndarray::Array4;
72    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
73    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
74    /// let b = Tensor::from_slice([10.0f32]);
75    /// let mut y = x.conv2d().weight(&w).bias(&b).call().unwrap();
76    /// y.realize().unwrap();
77    /// // Each output element: 9.0 + 10.0 = 19.0
78    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![19.0]);
79    /// ```
80    #[builder]
81    pub fn conv2d(
82        &self,
83        weight: &Tensor,
84        bias: Option<&Tensor>,
85        #[builder(default = 1)] groups: usize,
86        stride: Option<&[usize]>,
87        dilation: Option<&[usize]>,
88        padding: Option<&[(isize, isize)]>,
89        acc_dtype: Option<svod_dtype::DType>,
90    ) -> Result<Tensor> {
91        let x_shape = self.shape()?;
92        let w_shape = weight.shape()?;
93
94        let bs = x_shape[0].clone(); // SInt — concrete or symbolic (Variable batch)
95        let cin_ = x_shape[1].as_const().expect("channel dim must be concrete");
96        let cout = w_shape[0].as_const().expect("cout must be concrete");
97        let cin = w_shape[1].as_const().expect("cin/g must be concrete");
98
99        let hw: Vec<usize> = w_shape[2..].iter().map(|s| s.as_const().expect("kernel dim must be concrete")).collect();
100        let n_spatial = hw.len();
101
102        if x_shape.len() != w_shape.len() {
103            return Err(crate::error::Error::IrConstruction {
104                details: format!("input and weight must have same ndim, got {} and {}", x_shape.len(), w_shape.len()),
105            });
106        }
107        if groups * cin != cin_ {
108            return Err(crate::error::Error::IrConstruction {
109                details: format!("groups*cin/g ({}) != input channels ({cin_})", groups * cin),
110            });
111        }
112
113        let default_ones: Vec<usize> = vec![1; n_spatial];
114        let stride = stride.unwrap_or(&default_ones);
115        let dilation = dilation.unwrap_or(&default_ones);
116        let no_padding: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
117        let padding = padding.unwrap_or(&no_padding);
118
119        let mut x = self.clone();
120        if padding.iter().any(|&(b, e)| b != 0 || e != 0) {
121            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); 2];
122            full_pad.extend_from_slice(padding);
123            x = x.try_pad(&full_pad)?;
124        }
125
126        x = x.pool(&hw, stride, dilation)?;
127
128        let oyx: Vec<SInt> = {
129            let xs = x.shape()?;
130            xs[2..2 + n_spatial].to_vec()
131        };
132
133        let rcout = cout / groups;
134
135        // Reshape: (bs, groups, cin, 1, *oyx, *hw)
136        let mut reshape_dims: Vec<SInt> = vec![bs.clone(), groups.into(), cin.into(), 1usize.into()];
137        reshape_dims.extend(oyx.iter().cloned());
138        reshape_dims.extend(hw.iter().map(|&k| SInt::from(k)));
139        x = x.try_reshape(&reshape_dims)?;
140
141        // Expand: (bs, groups, cin, rcout, *oyx, *hw)
142        let mut expand_dims: Vec<SInt> = vec![bs.clone(), groups.into(), cin.into(), rcout.into()];
143        expand_dims.extend(oyx.iter().cloned());
144        expand_dims.extend(hw.iter().map(|&k| SInt::from(k)));
145        x = x.try_expand(&expand_dims)?;
146
147        // Permute: (bs, groups, rcout, *oyx, cin, *hw)
148        let mut perm: Vec<isize> = vec![0, 1, 3];
149        for j in 0..n_spatial {
150            perm.push(4 + j as isize);
151        }
152        perm.push(2);
153        for j in 0..n_spatial {
154            perm.push((4 + n_spatial + j) as isize);
155        }
156        x = x.try_permute(&perm)?;
157
158        // Reshape weight: (1, groups, rcout, *[1]*n_spatial, cin, *hw)
159        let mut w_reshape: Vec<isize> = vec![1, groups as isize, rcout as isize];
160        w_reshape.extend(std::iter::repeat_n(1isize, n_spatial));
161        w_reshape.push(cin as isize);
162        w_reshape.extend(hw.iter().map(|&k| k as isize));
163        let w = weight.try_reshape(&w_reshape)?;
164
165        x = x.try_mul(&w)?;
166
167        // Sum over last (1 + n_spatial) dims
168        let total_dims = x.ndim()?;
169        let reduce_axes: Vec<isize> = (0..(1 + n_spatial)).map(|i| (total_dims - 1 - i) as isize).collect();
170        x = x.sum_with().axes(AxisSpec::Multiple(reduce_axes)).keepdim(true).maybe_dtype(acc_dtype).call()?;
171
172        // Reshape to (bs, cout, *oyx)
173        let mut final_shape: Vec<SInt> = vec![bs.clone(), cout.into()];
174        final_shape.extend(oyx.iter().cloned());
175        x = x.try_reshape(&final_shape)?;
176
177        if let Some(bias) = bias {
178            let mut bias_shape: Vec<isize> = vec![1, cout as isize];
179            bias_shape.extend(std::iter::repeat_n(1isize, n_spatial));
180            let bias = bias.try_reshape(&bias_shape)?;
181            x = x.try_add(&bias)?;
182        }
183
184        Ok(x)
185    }
186
187    /// Transposed convolution (fractionally-strided convolution).
188    ///
189    /// Computes the gradient of a forward convolution, commonly used for upsampling.
190    /// Internally flips the kernel, interleaves zeros for stride > 1, computes
191    /// transposed padding, then delegates to [`conv2d`](Tensor::conv2d).
192    ///
193    /// Input `(N, Cin, *spatial)`, Weight `(Cin, Cout/groups, *kernel)`.
194    ///
195    /// # Examples
196    ///
197    /// Basic transposed convolution (upsampling):
198    ///
199    /// ```
200    /// # use svod_tensor::Tensor;
201    /// # use ndarray::Array4;
202    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
203    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
204    /// let mut y = x.conv_transpose2d().weight(&w).call().unwrap();
205    /// y.realize().unwrap();
206    /// let vals = y.as_vec::<f32>().unwrap();
207    /// assert_eq!(vals.len(), 16); // 4x4 output
208    /// // Center elements see full overlap of both input positions
209    /// assert_eq!(vals[5], 4.0);
210    /// ```
211    ///
212    /// With stride (stronger upsampling):
213    ///
214    /// ```
215    /// # use svod_tensor::Tensor;
216    /// # use ndarray::Array4;
217    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
218    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
219    /// let mut y = x.conv_transpose2d().weight(&w).stride(&[2, 2]).call().unwrap();
220    /// y.realize().unwrap();
221    /// let vals = y.as_vec::<f32>().unwrap();
222    /// assert_eq!(vals.len(), 25); // 5x5 output
223    /// ```
224    ///
225    /// With padding and output padding:
226    ///
227    /// ```
228    /// # use svod_tensor::Tensor;
229    /// # use ndarray::Array4;
230    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
231    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
232    /// let mut y = x.conv_transpose2d()
233    ///     .weight(&w)
234    ///     .stride(&[2, 2])
235    ///     .padding(&[(1, 1), (1, 1)])
236    ///     .output_padding(&[1, 1])
237    ///     .call()
238    ///     .unwrap();
239    /// y.realize().unwrap();
240    /// let vals = y.as_vec::<f32>().unwrap();
241    /// assert_eq!(vals.len(), 16); // 4x4 output
242    /// ```
243    #[builder]
244    pub fn conv_transpose2d(
245        &self,
246        weight: &Tensor,
247        bias: Option<&Tensor>,
248        #[builder(default = 1)] groups: usize,
249        stride: Option<&[usize]>,
250        dilation: Option<&[usize]>,
251        padding: Option<&[(isize, isize)]>,
252        output_padding: Option<&[usize]>,
253    ) -> Result<Tensor> {
254        let w_shape = weight.shape()?;
255        let hw: Vec<usize> = w_shape[2..].iter().map(|s| s.as_const().expect("kernel dim must be concrete")).collect();
256        let n_spatial = hw.len();
257
258        let default_ones: Vec<usize> = vec![1; n_spatial];
259        let default_zeros: Vec<usize> = vec![0; n_spatial];
260        let default_no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
261        let stride = stride.unwrap_or(&default_ones);
262        let dilation = dilation.unwrap_or(&default_ones);
263        let padding = padding.unwrap_or(&default_no_pad);
264        let output_padding = output_padding.unwrap_or(&default_zeros);
265
266        let cout_in = w_shape[0].as_const().unwrap();
267        let cin_g = w_shape[1].as_const().unwrap();
268        let rcout = cout_in / groups;
269
270        // Reshape to (groups, rcout, cin_g, *HW)
271        let mut unflatten_shape: Vec<isize> = vec![groups as isize, rcout as isize, cin_g as isize];
272        unflatten_shape.extend(hw.iter().map(|&k| k as isize));
273        let mut w = weight.try_reshape(&unflatten_shape)?;
274
275        // Transpose dim 1 and 2: (groups, cin_g, rcout, *HW)
276        w = w.try_transpose(1, 2)?;
277
278        // Flip kernel dims
279        let flip_axes: Vec<isize> = (3..(3 + n_spatial) as isize).collect();
280        w = w.flip(&flip_axes)?;
281
282        // Flatten back: (groups * cin_g, rcout, *HW)
283        let mut flat_shape: Vec<isize> = vec![(groups * cin_g) as isize, rcout as isize];
284        flat_shape.extend(hw.iter().map(|&k| k as isize));
285        w = w.try_reshape(&flat_shape)?;
286
287        // Handle stride > 1: interleave zeros across all spatial dims at once.
288        // Matches Tinygrad: (k) -> reshape (k,1) -> pad (k,s) -> reshape (k*s) -> shrink (k-(s-1))
289        // All spatial dims are processed in a single reshape/pad/reshape/shrink sequence
290        // to avoid cascading PAD operations that create exponential boolean condition trees.
291        let mut x = self.clone();
292        if stride.iter().any(|&s| s > 1) {
293            let x_shape = x.shape()?;
294            let spatial: Vec<usize> = x_shape[2..].iter().map(|s| s.as_const().unwrap()).collect();
295
296            // Step 1: reshape (N,C,h,w) -> (N,C,h,1,w,1)
297            let mut rshape: Vec<SInt> = vec![x_shape[0].clone(), x_shape[1].clone()];
298            for &k in &spatial {
299                rshape.push(k.into());
300                rshape.push(1usize.into());
301            }
302            x = x.try_reshape(&rshape)?;
303
304            // Step 2: pad inserted dims by (0, s-1): (N,C,h,s,w,s)
305            let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); 2];
306            for &s in stride.iter() {
307                pad_spec.push((0, 0));
308                pad_spec.push((0, (s - 1) as isize));
309            }
310            x = x.try_pad(&pad_spec)?;
311
312            // Step 3: reshape to merge pairs: (N,C,h*s,w*s)
313            let x_shape = x.shape()?;
314            let mut rshape: Vec<SInt> = vec![x_shape[0].clone(), x_shape[1].clone()];
315            for j in 0..n_spatial {
316                let a = x_shape[2 + j * 2].as_const().unwrap();
317                let b = x_shape[2 + j * 2 + 1].as_const().unwrap();
318                rshape.push((a * b).into());
319            }
320            x = x.try_reshape(&rshape)?;
321
322            // Step 4: shrink to remove trailing stride-1
323            // Use None for batch/channel dims (pass through).
324            let mut ranges: Vec<Option<(isize, isize)>> = vec![None, None];
325            for j in 0..n_spatial {
326                let new_size = spatial[j] * stride[j] - (stride[j] - 1);
327                ranges.push(Some((0, new_size as isize)));
328            }
329            x = x.try_shrink(&ranges)?;
330        }
331
332        // Compute transposed padding
333        let conv_padding: Vec<(isize, isize)> = (0..n_spatial)
334            .map(|j| {
335                let pb = padding[j].0;
336                let pa = padding[j].1;
337                let begin = (hw[j] as isize - 1) * dilation[j] as isize - pb;
338                let end = (hw[j] as isize - 1) * dilation[j] as isize - pa + output_padding[j] as isize;
339                (begin, end)
340            })
341            .collect();
342
343        x.conv2d().weight(&w).groups(groups).maybe_bias(bias).dilation(dilation).padding(&conv_padding).call()
344    }
345}