1use crate::sparse::core::SparseTensor;
7use torsh_core::{Result as TorshResult, TorshError};
8use torsh_tensor::Tensor;
9
10pub fn sparse_conv1d(
31 input: &SparseTensor,
32 weight: &Tensor,
33 bias: Option<&Tensor>,
34 stride: usize,
35 padding: usize,
36 dilation: usize,
37) -> TorshResult<SparseTensor> {
38 if input.ndim != 2 {
39 return Err(TorshError::invalid_argument_with_context(
40 "Input must be 2D tensor [batch_size, input_length]",
41 "sparse_conv1d",
42 ));
43 }
44
45 let weight_shape_binding = weight.shape();
46 let weight_shape = weight_shape_binding.dims();
47 if weight_shape.len() != 2 {
48 return Err(TorshError::invalid_argument_with_context(
49 "Weight must be 2D tensor [out_channels, kernel_size]",
50 "sparse_conv1d",
51 ));
52 }
53
54 let batch_size = input.shape[0];
55 let input_length = input.shape[1];
56 let out_channels = weight_shape[0];
57 let kernel_size = weight_shape[1];
58
59 let output_length =
61 (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;
62
63 let mut result_values = Vec::new();
64 let mut result_indices = Vec::new();
65
66 let input_values = input.values.to_vec()?;
67 let input_indices = input.indices.to_vec()?;
68 let weight_data = weight.to_vec()?;
69
70 for i in 0..input.nnz {
72 let batch_idx = input_indices[i] as usize;
73 let in_pos = input_indices[input.nnz + i] as usize;
74 let input_val = input_values[i];
75
76 for out_ch in 0..out_channels {
78 for k in 0..kernel_size {
79 let in_idx = in_pos + padding;
80 if in_idx >= k * dilation && (in_idx - k * dilation) % stride == 0 {
81 let out_pos = (in_idx - k * dilation) / stride;
82 if out_pos < output_length {
83 let weight_val = weight_data[out_ch * kernel_size + k];
84 let conv_val = input_val * weight_val;
85
86 if conv_val.abs() > 1e-8 {
87 result_values.push(conv_val);
88 result_indices.push(batch_idx as f32);
89 result_indices.push(out_pos as f32);
90 }
91 }
92 }
93 }
94 }
95 }
96
97 if let Some(bias_tensor) = bias {
99 let bias_data = bias_tensor.to_vec()?;
100 for batch in 0..batch_size {
101 for out_ch in 0..out_channels {
102 for pos in 0..output_length {
103 if bias_data[out_ch].abs() > 1e-8 {
104 result_values.push(bias_data[out_ch]);
105 result_indices.push(batch as f32);
106 result_indices.push(pos as f32);
107 }
108 }
109 }
110 }
111 }
112
113 let nnz = result_values.len();
114 let values = Tensor::from_data(result_values, vec![nnz], input.values.device())?;
115 let indices = Tensor::from_data(result_indices, vec![2, nnz], input.indices.device())?;
116 let shape = vec![batch_size, output_length];
117
118 let mut result = SparseTensor::new(values, indices, shape)?;
119 result.coalesce()?;
120 Ok(result)
121}
122
123pub fn sparse_conv2d(
144 input: &SparseTensor,
145 weight: &Tensor,
146 bias: Option<&Tensor>,
147 stride: (usize, usize),
148 padding: (usize, usize),
149 dilation: (usize, usize),
150) -> TorshResult<SparseTensor> {
151 if input.ndim != 4 {
152 return Err(TorshError::invalid_argument_with_context(
153 "Input must be 4D tensor [batch_size, channels, height, width]",
154 "sparse_conv2d",
155 ));
156 }
157
158 let weight_shape_binding = weight.shape();
159 let weight_shape = weight_shape_binding.dims();
160 if weight_shape.len() != 4 {
161 return Err(TorshError::invalid_argument_with_context(
162 "Weight must be 4D tensor [out_channels, in_channels, kernel_height, kernel_width]",
163 "sparse_conv2d",
164 ));
165 }
166
167 let batch_size = input.shape[0];
168 let in_channels = input.shape[1];
169 let in_height = input.shape[2];
170 let in_width = input.shape[3];
171
172 let out_channels = weight_shape[0];
173 let kernel_h = weight_shape[2];
174 let kernel_w = weight_shape[3];
175
176 let out_height = (in_height + 2 * padding.0 - dilation.0 * (kernel_h - 1) - 1) / stride.0 + 1;
178 let out_width = (in_width + 2 * padding.1 - dilation.1 * (kernel_w - 1) - 1) / stride.1 + 1;
179
180 let mut result_values = Vec::new();
181 let mut result_indices = Vec::new();
182
183 let input_values = input.values.to_vec()?;
184 let input_indices = input.indices.to_vec()?;
185 let weight_data = weight.to_vec()?;
186
187 for i in 0..input.nnz {
189 let batch_idx = input_indices[i] as usize;
190 let in_ch = input_indices[input.nnz + i] as usize;
191 let in_h = input_indices[2 * input.nnz + i] as usize;
192 let in_w = input_indices[3 * input.nnz + i] as usize;
193 let input_val = input_values[i];
194
195 for out_ch in 0..out_channels {
197 for kh in 0..kernel_h {
198 for kw in 0..kernel_w {
199 let h_idx = in_h + padding.0;
200 let w_idx = in_w + padding.1;
201
202 if h_idx >= kh * dilation.0
203 && w_idx >= kw * dilation.1
204 && (h_idx - kh * dilation.0) % stride.0 == 0
205 && (w_idx - kw * dilation.1) % stride.1 == 0
206 {
207 let out_h = (h_idx - kh * dilation.0) / stride.0;
208 let out_w = (w_idx - kw * dilation.1) / stride.1;
209
210 if out_h < out_height && out_w < out_width {
211 let weight_idx = out_ch * (in_channels * kernel_h * kernel_w)
212 + in_ch * (kernel_h * kernel_w)
213 + kh * kernel_w
214 + kw;
215 let weight_val = weight_data[weight_idx];
216 let conv_val = input_val * weight_val;
217
218 if conv_val.abs() > 1e-8 {
219 result_values.push(conv_val);
220 result_indices.push(batch_idx as f32);
221 result_indices.push(out_ch as f32);
222 result_indices.push(out_h as f32);
223 result_indices.push(out_w as f32);
224 }
225 }
226 }
227 }
228 }
229 }
230 }
231
232 if let Some(bias_tensor) = bias {
234 let bias_data = bias_tensor.to_vec()?;
235 for batch in 0..batch_size {
236 for out_ch in 0..out_channels {
237 for h in 0..out_height {
238 for w in 0..out_width {
239 if bias_data[out_ch].abs() > 1e-8 {
240 result_values.push(bias_data[out_ch]);
241 result_indices.push(batch as f32);
242 result_indices.push(out_ch as f32);
243 result_indices.push(h as f32);
244 result_indices.push(w as f32);
245 }
246 }
247 }
248 }
249 }
250 }
251
252 let nnz = result_values.len();
253 let values = Tensor::from_data(result_values, vec![nnz], input.values.device())?;
254 let indices = Tensor::from_data(result_indices, vec![4, nnz], input.indices.device())?;
255 let shape = vec![batch_size, out_channels, out_height, out_width];
256
257 let mut result = SparseTensor::new(values, indices, shape)?;
258 result.coalesce()?;
259 Ok(result)
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::core::sparse_coo_tensor;
    use torsh_core::DeviceType;

    /// A `[1, 5]` input and a width-2 kernel (stride 1, no padding, dilation 1)
    /// must produce a `[1, 4]` output.
    #[test]
    fn test_sparse_conv1d() -> TorshResult<()> {
        // Two non-zeros in batch 0, at positions 1 and 3.
        let coords = Tensor::from_data(vec![0.0, 0.0, 1.0, 3.0], vec![2, 2], DeviceType::Cpu)?;
        let entries = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?;
        let dense_shape = vec![1, 5];
        let sparse_input = sparse_coo_tensor(&coords, &entries, &dense_shape)?;

        // One output channel with a kernel of size two.
        let kernel = Tensor::from_data(vec![0.5, 0.3], vec![1, 2], DeviceType::Cpu)?;

        let output = sparse_conv1d(&sparse_input, &kernel, None, 1, 0, 1)?;

        assert_eq!(output.shape(), &[1, 4]);
        Ok(())
    }

    /// A `[1, 1, 3, 3]` input and a 2x2 kernel (unit stride, no padding, unit
    /// dilation) must produce a `[1, 1, 2, 2]` output.
    #[test]
    fn test_sparse_conv2d_simple() -> TorshResult<()> {
        // A single non-zero at (batch 0, channel 0, row 1, column 1).
        let coords = Tensor::from_data(vec![0.0, 0.0, 1.0, 1.0], vec![4, 1], DeviceType::Cpu)?;
        let entries = Tensor::from_data(vec![1.0], vec![1], DeviceType::Cpu)?;
        let dense_shape = vec![1, 1, 3, 3];
        let sparse_input = sparse_coo_tensor(&coords, &entries, &dense_shape)?;

        // One output channel, one input channel, 2x2 kernel.
        let kernel = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![1, 1, 2, 2],
            DeviceType::Cpu,
        )?;

        let output = sparse_conv2d(&sparse_input, &kernel, None, (1, 1), (0, 0), (1, 1))?;

        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        Ok(())
    }
}