svod_tensor/nn/conv.rs
1//! Convolution operations: conv2d, conv_transpose2d.
2
3use bon::bon;
4
5use svod_ir::SInt;
6
7use crate::Tensor;
8use crate::reduce::AxisSpec;
9
/// Shorthand for the crate-wide result type returned by the methods in this module.
type Result<T> = crate::Result<T>;
11
12#[bon]
13impl Tensor {
14 /// N-d convolution. Input `(N, Cin, *spatial)`, Weight `(Cout, Cin/groups, *kernel)`.
15 ///
16 /// Computes cross-correlation (conv without kernel flip) by extracting sliding
17 /// windows via [`pool`](Tensor::pool), then contracting against the weight tensor.
18 /// Supports grouped convolution, strided/dilated kernels, and asymmetric padding.
19 ///
20 /// # Examples
21 ///
22 /// Basic 2D convolution with uniform data:
23 ///
24 /// ```
25 /// # use svod_tensor::Tensor;
26 /// # use ndarray::Array4;
27 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
28 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
29 /// let mut y = x.conv2d().weight(&w).call().unwrap();
30 /// y.realize().unwrap();
31 /// // 3x3 kernel of ones on input of ones => each output element is 9.0
32 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 9]);
33 /// ```
34 ///
35 /// With stride:
36 ///
37 /// ```
38 /// # use svod_tensor::Tensor;
39 /// # use ndarray::Array4;
40 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
41 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
42 /// let mut y = x.conv2d().weight(&w).stride(&[2, 2]).call().unwrap();
43 /// y.realize().unwrap();
44 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
45 /// assert_eq!(shape, vec![1, 1, 2, 2]);
46 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 4]);
47 /// ```
48 ///
49 /// With padding:
50 ///
51 /// ```
52 /// # use svod_tensor::Tensor;
53 /// # use ndarray::Array4;
54 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
55 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
56 /// // padding=1 on each side: output matches input spatial dims
57 /// let mut y = x.conv2d().weight(&w).padding(&[(1, 1), (1, 1)]).call().unwrap();
58 /// y.realize().unwrap();
59 /// let vals = y.as_vec::<f32>().unwrap();
60 /// assert_eq!(vals.len(), 9); // 3x3 output
61 /// // Center element sees full 3x3 window of ones = 9.0
62 /// assert_eq!(vals[4], 9.0);
63 /// // Corner element sees 2x2 window = 4.0
64 /// assert_eq!(vals[0], 4.0);
65 /// ```
66 ///
67 /// With bias:
68 ///
69 /// ```
70 /// # use svod_tensor::Tensor;
71 /// # use ndarray::Array4;
72 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
73 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
74 /// let b = Tensor::from_slice([10.0f32]);
75 /// let mut y = x.conv2d().weight(&w).bias(&b).call().unwrap();
76 /// y.realize().unwrap();
77 /// // Each output element: 9.0 + 10.0 = 19.0
78 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![19.0]);
79 /// ```
80 #[builder]
81 pub fn conv2d(
82 &self,
83 weight: &Tensor,
84 bias: Option<&Tensor>,
85 #[builder(default = 1)] groups: usize,
86 stride: Option<&[usize]>,
87 dilation: Option<&[usize]>,
88 padding: Option<&[(isize, isize)]>,
89 acc_dtype: Option<svod_dtype::DType>,
90 ) -> Result<Tensor> {
91 let x_shape = self.shape()?;
92 let w_shape = weight.shape()?;
93
94 let bs = x_shape[0].clone(); // SInt — concrete or symbolic (Variable batch)
95 let cin_ = x_shape[1].as_const().expect("channel dim must be concrete");
96 let cout = w_shape[0].as_const().expect("cout must be concrete");
97 let cin = w_shape[1].as_const().expect("cin/g must be concrete");
98
99 let hw: Vec<usize> = w_shape[2..].iter().map(|s| s.as_const().expect("kernel dim must be concrete")).collect();
100 let n_spatial = hw.len();
101
102 if x_shape.len() != w_shape.len() {
103 return Err(crate::error::Error::IrConstruction {
104 details: format!("input and weight must have same ndim, got {} and {}", x_shape.len(), w_shape.len()),
105 });
106 }
107 if groups * cin != cin_ {
108 return Err(crate::error::Error::IrConstruction {
109 details: format!("groups*cin/g ({}) != input channels ({cin_})", groups * cin),
110 });
111 }
112
113 let default_ones: Vec<usize> = vec![1; n_spatial];
114 let stride = stride.unwrap_or(&default_ones);
115 let dilation = dilation.unwrap_or(&default_ones);
116 let no_padding: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
117 let padding = padding.unwrap_or(&no_padding);
118
119 let mut x = self.clone();
120 if padding.iter().any(|&(b, e)| b != 0 || e != 0) {
121 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); 2];
122 full_pad.extend_from_slice(padding);
123 x = x.try_pad(&full_pad)?;
124 }
125
126 x = x.pool(&hw, stride, dilation)?;
127
128 let oyx: Vec<SInt> = {
129 let xs = x.shape()?;
130 xs[2..2 + n_spatial].to_vec()
131 };
132
133 let rcout = cout / groups;
134
135 // Reshape: (bs, groups, cin, 1, *oyx, *hw)
136 let mut reshape_dims: Vec<SInt> = vec![bs.clone(), groups.into(), cin.into(), 1usize.into()];
137 reshape_dims.extend(oyx.iter().cloned());
138 reshape_dims.extend(hw.iter().map(|&k| SInt::from(k)));
139 x = x.try_reshape(&reshape_dims)?;
140
141 // Expand: (bs, groups, cin, rcout, *oyx, *hw)
142 let mut expand_dims: Vec<SInt> = vec![bs.clone(), groups.into(), cin.into(), rcout.into()];
143 expand_dims.extend(oyx.iter().cloned());
144 expand_dims.extend(hw.iter().map(|&k| SInt::from(k)));
145 x = x.try_expand(&expand_dims)?;
146
147 // Permute: (bs, groups, rcout, *oyx, cin, *hw)
148 let mut perm: Vec<isize> = vec![0, 1, 3];
149 for j in 0..n_spatial {
150 perm.push(4 + j as isize);
151 }
152 perm.push(2);
153 for j in 0..n_spatial {
154 perm.push((4 + n_spatial + j) as isize);
155 }
156 x = x.try_permute(&perm)?;
157
158 // Reshape weight: (1, groups, rcout, *[1]*n_spatial, cin, *hw)
159 let mut w_reshape: Vec<isize> = vec![1, groups as isize, rcout as isize];
160 w_reshape.extend(std::iter::repeat_n(1isize, n_spatial));
161 w_reshape.push(cin as isize);
162 w_reshape.extend(hw.iter().map(|&k| k as isize));
163 let w = weight.try_reshape(&w_reshape)?;
164
165 x = x.try_mul(&w)?;
166
167 // Sum over last (1 + n_spatial) dims
168 let total_dims = x.ndim()?;
169 let reduce_axes: Vec<isize> = (0..(1 + n_spatial)).map(|i| (total_dims - 1 - i) as isize).collect();
170 x = x.sum_with().axes(AxisSpec::Multiple(reduce_axes)).keepdim(true).maybe_dtype(acc_dtype).call()?;
171
172 // Reshape to (bs, cout, *oyx)
173 let mut final_shape: Vec<SInt> = vec![bs.clone(), cout.into()];
174 final_shape.extend(oyx.iter().cloned());
175 x = x.try_reshape(&final_shape)?;
176
177 if let Some(bias) = bias {
178 let mut bias_shape: Vec<isize> = vec![1, cout as isize];
179 bias_shape.extend(std::iter::repeat_n(1isize, n_spatial));
180 let bias = bias.try_reshape(&bias_shape)?;
181 x = x.try_add(&bias)?;
182 }
183
184 Ok(x)
185 }
186
187 /// Transposed convolution (fractionally-strided convolution).
188 ///
189 /// Computes the gradient of a forward convolution, commonly used for upsampling.
190 /// Internally flips the kernel, interleaves zeros for stride > 1, computes
191 /// transposed padding, then delegates to [`conv2d`](Tensor::conv2d).
192 ///
193 /// Input `(N, Cin, *spatial)`, Weight `(Cin, Cout/groups, *kernel)`.
194 ///
195 /// # Examples
196 ///
197 /// Basic transposed convolution (upsampling):
198 ///
199 /// ```
200 /// # use svod_tensor::Tensor;
201 /// # use ndarray::Array4;
202 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
203 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
204 /// let mut y = x.conv_transpose2d().weight(&w).call().unwrap();
205 /// y.realize().unwrap();
206 /// let vals = y.as_vec::<f32>().unwrap();
207 /// assert_eq!(vals.len(), 16); // 4x4 output
208 /// // Center elements see full overlap of both input positions
209 /// assert_eq!(vals[5], 4.0);
210 /// ```
211 ///
212 /// With stride (stronger upsampling):
213 ///
214 /// ```
215 /// # use svod_tensor::Tensor;
216 /// # use ndarray::Array4;
217 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
218 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
219 /// let mut y = x.conv_transpose2d().weight(&w).stride(&[2, 2]).call().unwrap();
220 /// y.realize().unwrap();
221 /// let vals = y.as_vec::<f32>().unwrap();
222 /// assert_eq!(vals.len(), 25); // 5x5 output
223 /// ```
224 ///
225 /// With padding and output padding:
226 ///
227 /// ```
228 /// # use svod_tensor::Tensor;
229 /// # use ndarray::Array4;
230 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
231 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
232 /// let mut y = x.conv_transpose2d()
233 /// .weight(&w)
234 /// .stride(&[2, 2])
235 /// .padding(&[(1, 1), (1, 1)])
236 /// .output_padding(&[1, 1])
237 /// .call()
238 /// .unwrap();
239 /// y.realize().unwrap();
240 /// let vals = y.as_vec::<f32>().unwrap();
241 /// assert_eq!(vals.len(), 16); // 4x4 output
242 /// ```
243 #[builder]
244 pub fn conv_transpose2d(
245 &self,
246 weight: &Tensor,
247 bias: Option<&Tensor>,
248 #[builder(default = 1)] groups: usize,
249 stride: Option<&[usize]>,
250 dilation: Option<&[usize]>,
251 padding: Option<&[(isize, isize)]>,
252 output_padding: Option<&[usize]>,
253 ) -> Result<Tensor> {
254 let w_shape = weight.shape()?;
255 let hw: Vec<usize> = w_shape[2..].iter().map(|s| s.as_const().expect("kernel dim must be concrete")).collect();
256 let n_spatial = hw.len();
257
258 let default_ones: Vec<usize> = vec![1; n_spatial];
259 let default_zeros: Vec<usize> = vec![0; n_spatial];
260 let default_no_pad: Vec<(isize, isize)> = vec![(0, 0); n_spatial];
261 let stride = stride.unwrap_or(&default_ones);
262 let dilation = dilation.unwrap_or(&default_ones);
263 let padding = padding.unwrap_or(&default_no_pad);
264 let output_padding = output_padding.unwrap_or(&default_zeros);
265
266 let cout_in = w_shape[0].as_const().unwrap();
267 let cin_g = w_shape[1].as_const().unwrap();
268 let rcout = cout_in / groups;
269
270 // Reshape to (groups, rcout, cin_g, *HW)
271 let mut unflatten_shape: Vec<isize> = vec![groups as isize, rcout as isize, cin_g as isize];
272 unflatten_shape.extend(hw.iter().map(|&k| k as isize));
273 let mut w = weight.try_reshape(&unflatten_shape)?;
274
275 // Transpose dim 1 and 2: (groups, cin_g, rcout, *HW)
276 w = w.try_transpose(1, 2)?;
277
278 // Flip kernel dims
279 let flip_axes: Vec<isize> = (3..(3 + n_spatial) as isize).collect();
280 w = w.flip(&flip_axes)?;
281
282 // Flatten back: (groups * cin_g, rcout, *HW)
283 let mut flat_shape: Vec<isize> = vec![(groups * cin_g) as isize, rcout as isize];
284 flat_shape.extend(hw.iter().map(|&k| k as isize));
285 w = w.try_reshape(&flat_shape)?;
286
287 // Handle stride > 1: interleave zeros across all spatial dims at once.
288 // Matches Tinygrad: (k) -> reshape (k,1) -> pad (k,s) -> reshape (k*s) -> shrink (k-(s-1))
289 // All spatial dims are processed in a single reshape/pad/reshape/shrink sequence
290 // to avoid cascading PAD operations that create exponential boolean condition trees.
291 let mut x = self.clone();
292 if stride.iter().any(|&s| s > 1) {
293 let x_shape = x.shape()?;
294 let spatial: Vec<usize> = x_shape[2..].iter().map(|s| s.as_const().unwrap()).collect();
295
296 // Step 1: reshape (N,C,h,w) -> (N,C,h,1,w,1)
297 let mut rshape: Vec<SInt> = vec![x_shape[0].clone(), x_shape[1].clone()];
298 for &k in &spatial {
299 rshape.push(k.into());
300 rshape.push(1usize.into());
301 }
302 x = x.try_reshape(&rshape)?;
303
304 // Step 2: pad inserted dims by (0, s-1): (N,C,h,s,w,s)
305 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); 2];
306 for &s in stride.iter() {
307 pad_spec.push((0, 0));
308 pad_spec.push((0, (s - 1) as isize));
309 }
310 x = x.try_pad(&pad_spec)?;
311
312 // Step 3: reshape to merge pairs: (N,C,h*s,w*s)
313 let x_shape = x.shape()?;
314 let mut rshape: Vec<SInt> = vec![x_shape[0].clone(), x_shape[1].clone()];
315 for j in 0..n_spatial {
316 let a = x_shape[2 + j * 2].as_const().unwrap();
317 let b = x_shape[2 + j * 2 + 1].as_const().unwrap();
318 rshape.push((a * b).into());
319 }
320 x = x.try_reshape(&rshape)?;
321
322 // Step 4: shrink to remove trailing stride-1
323 // Use None for batch/channel dims (pass through).
324 let mut ranges: Vec<Option<(isize, isize)>> = vec![None, None];
325 for j in 0..n_spatial {
326 let new_size = spatial[j] * stride[j] - (stride[j] - 1);
327 ranges.push(Some((0, new_size as isize)));
328 }
329 x = x.try_shrink(&ranges)?;
330 }
331
332 // Compute transposed padding
333 let conv_padding: Vec<(isize, isize)> = (0..n_spatial)
334 .map(|j| {
335 let pb = padding[j].0;
336 let pa = padding[j].1;
337 let begin = (hw[j] as isize - 1) * dilation[j] as isize - pb;
338 let end = (hw[j] as isize - 1) * dilation[j] as isize - pa + output_padding[j] as isize;
339 (begin, end)
340 })
341 .collect();
342
343 x.conv2d().weight(&w).groups(groups).maybe_bias(bias).dilation(dilation).padding(&conv_padding).call()
344 }
345}