1use std::collections::HashMap;
7use torsh_core::{Result as TorshResult, TorshError};
8use torsh_tensor::Tensor;
9
/// Sparse tensor stored in COO (coordinate) format: the nonzero values plus
/// their multi-dimensional coordinates.
#[derive(Debug, Clone)]
pub struct SparseTensor {
    /// 1-D tensor holding the stored (nonzero) values, length `nnz`.
    pub values: Tensor,
    /// 2-D coordinate tensor of shape `[ndim, nnz]`; row `d` holds the
    /// dimension-`d` coordinate of every stored entry (coordinates are
    /// stored as floats and cast back to `usize` on use).
    pub indices: Tensor,
    /// Logical (dense) shape of the tensor.
    pub shape: Vec<usize>,
    /// Number of dimensions; equals `shape.len()`.
    pub ndim: usize,
    /// Number of stored entries.
    pub nnz: usize,
    /// Whether the entries are known to be duplicate-free
    /// (set by `coalesce`).
    pub is_coalesced: bool,
}
26
27impl SparseTensor {
28 pub fn new(values: Tensor, indices: Tensor, shape: Vec<usize>) -> TorshResult<Self> {
30 let values_shape = values.shape().dims().to_vec();
31 let indices_shape = indices.shape().dims().to_vec();
32
33 if values_shape.len() != 1 {
34 return Err(TorshError::invalid_argument_with_context(
35 "Values must be a 1D tensor",
36 "SparseTensor::new",
37 ));
38 }
39
40 if indices_shape.len() != 2 {
41 return Err(TorshError::invalid_argument_with_context(
42 "Indices must be a 2D tensor",
43 "SparseTensor::new",
44 ));
45 }
46
47 let nnz = values_shape[0];
48 let ndim = shape.len();
49
50 if indices_shape[0] != ndim {
51 return Err(TorshError::invalid_argument_with_context(
52 "Indices first dimension must equal tensor ndim",
53 "SparseTensor::new",
54 ));
55 }
56
57 if indices_shape[1] != nnz {
58 return Err(TorshError::invalid_argument_with_context(
59 "Indices second dimension must equal number of values",
60 "SparseTensor::new",
61 ));
62 }
63
64 Ok(SparseTensor {
65 values,
66 indices,
67 shape,
68 ndim,
69 nnz,
70 is_coalesced: false,
71 })
72 }
73
74 pub fn from_dense(dense: &Tensor) -> TorshResult<Self> {
76 let shape = dense.shape().dims().to_vec();
77 let ndim = shape.len();
78
79 let dense_data = dense.to_vec()?;
81 let mut values_vec = Vec::new();
82 let mut coords_vec = Vec::new(); let total_elements: usize = shape.iter().product();
86 for flat_idx in 0..total_elements {
87 let value = dense_data[flat_idx];
88 if value.abs() > 1e-8 {
89 values_vec.push(value);
91
92 let mut remaining = flat_idx;
94 let mut coords = Vec::with_capacity(ndim);
95 for &dim_size in shape.iter().rev() {
96 coords.push(remaining % dim_size);
97 remaining /= dim_size;
98 }
99 coords.reverse();
100
101 coords_vec.push(coords);
102 }
103 }
104
105 let nnz = values_vec.len();
106
107 let mut indices_vec = Vec::with_capacity(ndim * nnz);
110 for dim in 0..ndim {
111 for coords in &coords_vec {
112 indices_vec.push(coords[dim] as f32);
113 }
114 }
115
116 let values = Tensor::from_data(values_vec, vec![nnz], dense.device())?;
117 let indices = Tensor::from_data(indices_vec, vec![ndim, nnz], dense.device())?;
118
119 Ok(SparseTensor {
120 values,
121 indices,
122 shape,
123 ndim,
124 nnz,
125 is_coalesced: false,
126 })
127 }
128
129 pub fn to_dense(&self) -> TorshResult<Tensor> {
131 let total_elements: usize = self.shape.iter().product();
132 let mut dense_data = vec![0.0f32; total_elements];
133
134 let values_data = self.values.to_vec()?;
135 let indices_data = self.indices.to_vec()?;
136
137 for i in 0..self.nnz {
138 let mut flat_idx = 0;
140 let mut stride = 1;
141
142 for j in (0..self.ndim).rev() {
143 let coord = indices_data[j * self.nnz + i] as usize;
144 flat_idx += coord * stride;
145 stride *= self.shape[j];
146 }
147
148 dense_data[flat_idx] = values_data[i];
149 }
150
151 Tensor::from_data(dense_data, self.shape.clone(), self.values.device())
152 }
153
154 pub fn coalesce(&mut self) -> TorshResult<()> {
156 if self.is_coalesced {
157 return Ok(());
158 }
159
160 let values_data = self.values.to_vec()?;
161 let indices_data = self.indices.to_vec()?;
162
163 let mut index_to_value: HashMap<Vec<usize>, f32> = HashMap::new();
165
166 for i in 0..self.nnz {
167 let mut coords = Vec::with_capacity(self.ndim);
168 for j in 0..self.ndim {
169 coords.push(indices_data[j * self.nnz + i] as usize);
170 }
171
172 *index_to_value.entry(coords).or_insert(0.0) += values_data[i];
173 }
174
175 let mut new_values = Vec::new();
177 let mut new_indices = Vec::new();
178
179 for (coords, value) in index_to_value {
180 if value.abs() > 1e-8 {
181 new_values.push(value);
182 for coord in coords {
183 new_indices.push(coord as f32);
184 }
185 }
186 }
187
188 let new_nnz = new_values.len();
189 self.values = Tensor::from_data(new_values, vec![new_nnz], self.values.device())?;
190 self.indices =
191 Tensor::from_data(new_indices, vec![self.ndim, new_nnz], self.indices.device())?;
192 self.nnz = new_nnz;
193 self.is_coalesced = true;
194
195 Ok(())
196 }
197
198 pub fn nnz(&self) -> usize {
200 self.nnz
201 }
202
203 pub fn shape(&self) -> &[usize] {
205 &self.shape
206 }
207
208 pub fn ndim(&self) -> usize {
210 self.ndim
211 }
212
213 pub fn is_coalesced(&self) -> bool {
215 self.is_coalesced
216 }
217}
218
219pub fn sparse_coo_tensor(
221 indices: &Tensor,
222 values: &Tensor,
223 shape: &[usize],
224) -> TorshResult<SparseTensor> {
225 SparseTensor::new(values.clone(), indices.clone(), shape.to_vec())
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid 1-D values / 2-D indices pair should construct successfully
    /// and report the expected nnz, shape, and ndim.
    #[test]
    fn test_sparse_tensor_creation() -> TorshResult<()> {
        let vals = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], torsh_core::DeviceType::Cpu)?;
        let idxs = Tensor::from_data(
            vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
            vec![2, 3],
            torsh_core::DeviceType::Cpu,
        )?;

        let sparse = SparseTensor::new(vals, idxs, vec![3, 3])?;

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
        assert_eq!(sparse.ndim(), 2);
        Ok(())
    }

    /// Densifying entries at (0,0) and (1,1) of a 2x2 tensor should place the
    /// values on the diagonal and zeros elsewhere.
    #[test]
    fn test_sparse_to_dense() -> TorshResult<()> {
        let vals = Tensor::from_data(vec![1.0, 2.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let idxs = Tensor::from_data(
            vec![0.0, 1.0, 0.0, 1.0],
            vec![2, 2],
            torsh_core::DeviceType::Cpu,
        )?;

        let dense = SparseTensor::new(vals, idxs, vec![2, 2])?.to_dense()?;
        let actual = dense.to_vec()?;
        let expected = [1.0, 0.0, 0.0, 2.0];

        assert!(actual
            .iter()
            .zip(expected.iter())
            .all(|(a, e)| (a - e).abs() < 1e-6));
        Ok(())
    }
}