// torsh_functional/sparse/convolution.rs

//! Sparse tensor convolution operations
//!
//! This module provides convolution operations optimized for sparse tensors,
//! including 1D and 2D convolutions with support for padding, stride, and dilation.

use crate::sparse::core::SparseTensor;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
10/// Sparse 1D convolution
11///
12/// Performs 1D convolution on sparse input tensors with dense kernels.
13/// This is efficient for sparse inputs as it only processes non-zero elements.
14///
15/// # Mathematical Formula
16/// For input x and kernel w:
17/// `y[b, i] = Σ(x[b, i + k*d - p] * w[o, k]) + bias[o]`
18/// where b=batch, i=output position, k=kernel position, d=dilation, p=padding, o=output channel
19///
20/// # Arguments
21/// * `input` - Sparse input tensor \[batch_size, input_length\]
22/// * `weight` - Dense weight tensor \[out_channels, kernel_size\]
23/// * `bias` - Optional bias tensor \[out_channels\]
24/// * `stride` - Convolution stride
25/// * `padding` - Zero padding
26/// * `dilation` - Kernel dilation
27///
28/// # Returns
29/// Sparse output tensor after convolution
30pub fn sparse_conv1d(
31    input: &SparseTensor,
32    weight: &Tensor,
33    bias: Option<&Tensor>,
34    stride: usize,
35    padding: usize,
36    dilation: usize,
37) -> TorshResult<SparseTensor> {
38    if input.ndim != 2 {
39        return Err(TorshError::invalid_argument_with_context(
40            "Input must be 2D tensor [batch_size, input_length]",
41            "sparse_conv1d",
42        ));
43    }
44
45    let weight_shape_binding = weight.shape();
46    let weight_shape = weight_shape_binding.dims();
47    if weight_shape.len() != 2 {
48        return Err(TorshError::invalid_argument_with_context(
49            "Weight must be 2D tensor [out_channels, kernel_size]",
50            "sparse_conv1d",
51        ));
52    }
53
54    let batch_size = input.shape[0];
55    let input_length = input.shape[1];
56    let out_channels = weight_shape[0];
57    let kernel_size = weight_shape[1];
58
59    // Calculate output length
60    let output_length =
61        (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;
62
63    let mut result_values = Vec::new();
64    let mut result_indices = Vec::new();
65
66    let input_values = input.values.to_vec()?;
67    let input_indices = input.indices.to_vec()?;
68    let weight_data = weight.to_vec()?;
69
70    // For each non-zero input element
71    for i in 0..input.nnz {
72        let batch_idx = input_indices[i] as usize;
73        let in_pos = input_indices[input.nnz + i] as usize;
74        let input_val = input_values[i];
75
76        // Apply convolution kernel
77        for out_ch in 0..out_channels {
78            for k in 0..kernel_size {
79                let in_idx = in_pos + padding;
80                if in_idx >= k * dilation && (in_idx - k * dilation) % stride == 0 {
81                    let out_pos = (in_idx - k * dilation) / stride;
82                    if out_pos < output_length {
83                        let weight_val = weight_data[out_ch * kernel_size + k];
84                        let conv_val = input_val * weight_val;
85
86                        if conv_val.abs() > 1e-8 {
87                            result_values.push(conv_val);
88                            result_indices.push(batch_idx as f32);
89                            result_indices.push(out_pos as f32);
90                        }
91                    }
92                }
93            }
94        }
95    }
96
97    // Add bias if present
98    if let Some(bias_tensor) = bias {
99        let bias_data = bias_tensor.to_vec()?;
100        for batch in 0..batch_size {
101            for out_ch in 0..out_channels {
102                for pos in 0..output_length {
103                    if bias_data[out_ch].abs() > 1e-8 {
104                        result_values.push(bias_data[out_ch]);
105                        result_indices.push(batch as f32);
106                        result_indices.push(pos as f32);
107                    }
108                }
109            }
110        }
111    }
112
113    let nnz = result_values.len();
114    let values = Tensor::from_data(result_values, vec![nnz], input.values.device())?;
115    let indices = Tensor::from_data(result_indices, vec![2, nnz], input.indices.device())?;
116    let shape = vec![batch_size, output_length];
117
118    let mut result = SparseTensor::new(values, indices, shape)?;
119    result.coalesce()?;
120    Ok(result)
121}
122
123/// Sparse 2D convolution
124///
125/// Performs 2D convolution on sparse input tensors with dense kernels.
126/// Optimized for sparse inputs by only processing non-zero elements.
127///
128/// # Mathematical Formula
129/// For input x and kernel w:
130/// `y[b, o, h, w] = Σ(x[b, i, h + kh*dh - ph, w + kw*dw - pw] * w[o, i, kh, kw]) + bias[o]`
131/// where b=batch, o=output channel, i=input channel, h,w=spatial positions, kh,kw=kernel positions
132///
133/// # Arguments
134/// * `input` - Sparse input tensor \[batch_size, channels, height, width\]
135/// * `weight` - Dense weight tensor \[out_channels, in_channels, kernel_height, kernel_width\]
136/// * `bias` - Optional bias tensor \[out_channels\]
137/// * `stride` - Convolution stride (height, width)
138/// * `padding` - Zero padding (height, width)
139/// * `dilation` - Kernel dilation (height, width)
140///
141/// # Returns
142/// Sparse output tensor after convolution
143pub fn sparse_conv2d(
144    input: &SparseTensor,
145    weight: &Tensor,
146    bias: Option<&Tensor>,
147    stride: (usize, usize),
148    padding: (usize, usize),
149    dilation: (usize, usize),
150) -> TorshResult<SparseTensor> {
151    if input.ndim != 4 {
152        return Err(TorshError::invalid_argument_with_context(
153            "Input must be 4D tensor [batch_size, channels, height, width]",
154            "sparse_conv2d",
155        ));
156    }
157
158    let weight_shape_binding = weight.shape();
159    let weight_shape = weight_shape_binding.dims();
160    if weight_shape.len() != 4 {
161        return Err(TorshError::invalid_argument_with_context(
162            "Weight must be 4D tensor [out_channels, in_channels, kernel_height, kernel_width]",
163            "sparse_conv2d",
164        ));
165    }
166
167    let batch_size = input.shape[0];
168    let in_channels = input.shape[1];
169    let in_height = input.shape[2];
170    let in_width = input.shape[3];
171
172    let out_channels = weight_shape[0];
173    let kernel_h = weight_shape[2];
174    let kernel_w = weight_shape[3];
175
176    // Calculate output dimensions
177    let out_height = (in_height + 2 * padding.0 - dilation.0 * (kernel_h - 1) - 1) / stride.0 + 1;
178    let out_width = (in_width + 2 * padding.1 - dilation.1 * (kernel_w - 1) - 1) / stride.1 + 1;
179
180    let mut result_values = Vec::new();
181    let mut result_indices = Vec::new();
182
183    let input_values = input.values.to_vec()?;
184    let input_indices = input.indices.to_vec()?;
185    let weight_data = weight.to_vec()?;
186
187    // For each non-zero input element
188    for i in 0..input.nnz {
189        let batch_idx = input_indices[i] as usize;
190        let in_ch = input_indices[input.nnz + i] as usize;
191        let in_h = input_indices[2 * input.nnz + i] as usize;
192        let in_w = input_indices[3 * input.nnz + i] as usize;
193        let input_val = input_values[i];
194
195        // Apply convolution kernel
196        for out_ch in 0..out_channels {
197            for kh in 0..kernel_h {
198                for kw in 0..kernel_w {
199                    let h_idx = in_h + padding.0;
200                    let w_idx = in_w + padding.1;
201
202                    if h_idx >= kh * dilation.0
203                        && w_idx >= kw * dilation.1
204                        && (h_idx - kh * dilation.0) % stride.0 == 0
205                        && (w_idx - kw * dilation.1) % stride.1 == 0
206                    {
207                        let out_h = (h_idx - kh * dilation.0) / stride.0;
208                        let out_w = (w_idx - kw * dilation.1) / stride.1;
209
210                        if out_h < out_height && out_w < out_width {
211                            let weight_idx = out_ch * (in_channels * kernel_h * kernel_w)
212                                + in_ch * (kernel_h * kernel_w)
213                                + kh * kernel_w
214                                + kw;
215                            let weight_val = weight_data[weight_idx];
216                            let conv_val = input_val * weight_val;
217
218                            if conv_val.abs() > 1e-8 {
219                                result_values.push(conv_val);
220                                result_indices.push(batch_idx as f32);
221                                result_indices.push(out_ch as f32);
222                                result_indices.push(out_h as f32);
223                                result_indices.push(out_w as f32);
224                            }
225                        }
226                    }
227                }
228            }
229        }
230    }
231
232    // Add bias if present
233    if let Some(bias_tensor) = bias {
234        let bias_data = bias_tensor.to_vec()?;
235        for batch in 0..batch_size {
236            for out_ch in 0..out_channels {
237                for h in 0..out_height {
238                    for w in 0..out_width {
239                        if bias_data[out_ch].abs() > 1e-8 {
240                            result_values.push(bias_data[out_ch]);
241                            result_indices.push(batch as f32);
242                            result_indices.push(out_ch as f32);
243                            result_indices.push(h as f32);
244                            result_indices.push(w as f32);
245                        }
246                    }
247                }
248            }
249        }
250    }
251
252    let nnz = result_values.len();
253    let values = Tensor::from_data(result_values, vec![nnz], input.values.device())?;
254    let indices = Tensor::from_data(result_indices, vec![4, nnz], input.indices.device())?;
255    let shape = vec![batch_size, out_channels, out_height, out_width];
256
257    let mut result = SparseTensor::new(values, indices, shape)?;
258    result.coalesce()?;
259    Ok(result)
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::core::sparse_coo_tensor;

    /// A single-batch sparse vector of length 5 convolved with a size-2
    /// kernel (stride 1, no padding, no dilation) yields length-4 output.
    #[test]
    fn test_sparse_conv1d() -> TorshResult<()> {
        let cpu = torsh_core::DeviceType::Cpu;

        // Non-zeros: 1.0 at (batch 0, pos 1) and 2.0 at (batch 0, pos 3).
        let vals = Tensor::from_data(vec![1.0, 2.0], vec![2], cpu)?;
        let idxs = Tensor::from_data(vec![0.0, 0.0, 1.0, 3.0], vec![2, 2], cpu)?;
        let dims = vec![1, 5]; // batch_size=1, length=5
        let sparse = sparse_coo_tensor(&idxs, &vals, &dims)?;

        // One output channel with kernel [0.5, 0.3].
        let kernel = Tensor::from_data(vec![0.5, 0.3], vec![1, 2], cpu)?;

        let out = sparse_conv1d(&sparse, &kernel, None, 1, 0, 1)?;

        // output length = (5 - 2) / 1 + 1 = 4
        assert_eq!(out.shape(), &[1, 4]);

        Ok(())
    }

    /// One non-zero at the centre of a 3x3 input convolved with a 2x2
    /// kernel (stride 1, no padding) produces a 2x2 spatial output.
    #[test]
    fn test_sparse_conv2d_simple() -> TorshResult<()> {
        let cpu = torsh_core::DeviceType::Cpu;

        // Single non-zero: 1.0 at [batch=0, channel=0, h=1, w=1].
        let vals = Tensor::from_data(vec![1.0], vec![1], cpu)?;
        let idxs = Tensor::from_data(vec![0.0, 0.0, 1.0, 1.0], vec![4, 1], cpu)?;
        let dims = vec![1, 1, 3, 3];
        let sparse = sparse_coo_tensor(&idxs, &vals, &dims)?;

        // 2x2 kernel for a single in/out channel pair.
        let kernel = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![1, 1, 2, 2], // [out_ch=1, in_ch=1, h=2, w=2]
            cpu,
        )?;

        let out = sparse_conv2d(&sparse, &kernel, None, (1, 1), (0, 0), (1, 1))?;

        // spatial output: (3 - 2) / 1 + 1 = 2 in each dimension
        assert_eq!(out.shape(), &[1, 1, 2, 2]);

        Ok(())
    }
}