1use crate::Scalar;
8use crate::error::{CoreError, Result};
9use crate::tensor::Tensor;
10
11use std::collections::HashMap;
12
/// A sparse tensor stored in coordinate (COO) format.
///
/// Stored entry `k` sits at coordinate
/// `(indices[0][k], indices[1][k], ..., indices[ndim - 1][k])` and holds
/// `values[k]`. Coordinates are not required to be unique or sorted: operations
/// such as `add` can leave duplicate coordinates, and multiplying by zero can
/// leave explicit zeros. `coalesce` merges duplicates, drops zeros and sorts
/// the entries.
#[derive(Debug, Clone)]
pub struct SparseTensor<T: Scalar> {
    /// One index vector per dimension: `indices[d][k]` is the coordinate of
    /// stored entry `k` along dimension `d`. Every inner vector has length `nnz`.
    indices: Vec<Vec<usize>>,
    /// Stored values, parallel to each inner vector of `indices`.
    values: Vec<T>,
    /// Logical (dense) shape of the tensor.
    shape: Vec<usize>,
    /// Number of stored entries. It is always equal to `values.len()`, and it
    /// counts duplicates and explicit zeros until the tensor is coalesced.
    nnz: usize,
}
42
43impl<T: Scalar> SparseTensor<T> {
44 pub fn new(indices: Vec<Vec<usize>>, values: Vec<T>, shape: Vec<usize>) -> Result<Self> {
59 let ndim = shape.len();
60 if indices.len() != ndim {
61 return Err(CoreError::InvalidArgument {
62 reason: "indices length must equal number of dimensions",
63 });
64 }
65 let nnz = values.len();
66 for (dim, idx_vec) in indices.iter().enumerate() {
67 if idx_vec.len() != nnz {
68 return Err(CoreError::InvalidArgument {
69 reason: "all index vectors must have the same length as values",
70 });
71 }
72 for &idx in idx_vec {
73 if idx >= shape[dim] {
74 return Err(CoreError::IndexOutOfBounds {
75 index: vec![idx],
76 shape: shape.clone(),
77 });
78 }
79 }
80 }
81 Ok(Self {
82 indices,
83 values,
84 shape,
85 nnz,
86 })
87 }
88
89 pub fn from_dense(tensor: &Tensor<T>) -> Self {
101 let shape = tensor.shape().to_vec();
102 let ndim = shape.len();
103 let strides = tensor.strides().to_vec();
104
105 let mut indices: Vec<Vec<usize>> = vec![Vec::new(); ndim];
106 let mut values = Vec::new();
107
108 for (flat, &val) in tensor.as_slice().iter().enumerate() {
109 if val != T::zero() {
110 let mut remaining = flat;
112 for (dim, &stride) in strides.iter().enumerate() {
113 let coord = if stride == 0 { 0 } else { remaining / stride };
114 indices[dim].push(coord);
115 if stride != 0 {
116 remaining %= stride;
117 }
118 }
119 values.push(val);
120 }
121 }
122
123 let nnz = values.len();
124 Self {
125 indices,
126 values,
127 shape,
128 nnz,
129 }
130 }
131
132 pub fn zeros(shape: Vec<usize>) -> Self {
143 let ndim = shape.len();
144 Self {
145 indices: vec![Vec::new(); ndim],
146 values: Vec::new(),
147 shape,
148 nnz: 0,
149 }
150 }
151
152 #[inline]
158 pub fn nnz(&self) -> usize {
159 self.nnz
160 }
161
162 #[inline]
164 pub fn shape(&self) -> &[usize] {
165 &self.shape
166 }
167
168 #[inline]
170 pub fn ndim(&self) -> usize {
171 self.shape.len()
172 }
173
174 pub fn density(&self) -> f64 {
178 let total: usize = self.shape.iter().product();
179 if total == 0 {
180 return 0.0;
181 }
182 self.nnz as f64 / total as f64
183 }
184
185 #[inline]
187 pub fn values(&self) -> &[T] {
188 &self.values
189 }
190
191 #[inline]
193 pub fn indices(&self) -> &[Vec<usize>] {
194 &self.indices
195 }
196
197 pub fn to_dense(&self) -> Result<Tensor<T>> {
211 let mut dense = Tensor::zeros(self.shape.clone());
212 let strides = dense.strides().to_vec();
213 for k in 0..self.nnz {
214 let flat: usize = self
215 .indices
216 .iter()
217 .enumerate()
218 .map(|(dim, idx_vec)| idx_vec[k] * strides[dim])
219 .sum();
220 let data = dense.as_mut_slice();
221 data[flat] += self.values[k];
222 }
223 Ok(dense)
224 }
225
226 pub fn add(&self, other: &SparseTensor<T>) -> Result<SparseTensor<T>> {
239 if self.shape != other.shape {
240 return Err(CoreError::DimensionMismatch {
241 expected: self.shape.clone(),
242 got: other.shape.clone(),
243 });
244 }
245 let new_nnz = self.nnz + other.nnz;
246 let mut indices: Vec<Vec<usize>> = self
247 .indices
248 .iter()
249 .zip(other.indices.iter())
250 .map(|(a, b)| {
251 let mut v = a.clone();
252 v.extend_from_slice(b);
253 v
254 })
255 .collect();
256 if indices.is_empty() && self.shape.is_empty() {
258 indices = Vec::new();
259 }
260 let mut values = self.values.clone();
261 values.extend_from_slice(&other.values);
262 Ok(SparseTensor {
263 indices,
264 values,
265 shape: self.shape.clone(),
266 nnz: new_nnz,
267 })
268 }
269
270 pub fn scalar_mul(&self, scalar: T) -> SparseTensor<T> {
285 SparseTensor {
286 indices: self.indices.clone(),
287 values: self.values.iter().map(|&v| v * scalar).collect(),
288 shape: self.shape.clone(),
289 nnz: self.nnz,
290 }
291 }
292
293 pub fn sparse_matmul(&self, other: &SparseTensor<T>) -> Result<SparseTensor<T>> {
307 if self.ndim() != 2 || other.ndim() != 2 {
308 return Err(CoreError::InvalidArgument {
309 reason: "sparse_matmul requires 2D tensors",
310 });
311 }
312 let m = self.shape[0];
313 let k_self = self.shape[1];
314 let k_other = other.shape[0];
315 let n = other.shape[1];
316 if k_self != k_other {
317 return Err(CoreError::DimensionMismatch {
318 expected: self.shape.clone(),
319 got: other.shape.clone(),
320 });
321 }
322
323 let mut other_row_map: HashMap<usize, Vec<(usize, T)>> = HashMap::new();
327 for k in 0..other.nnz {
328 other_row_map
329 .entry(other.indices[0][k])
330 .or_default()
331 .push((other.indices[1][k], other.values[k]));
332 }
333
334 let mut result_map: HashMap<(usize, usize), T> = HashMap::new();
335 for k in 0..self.nnz {
336 let row = self.indices[0][k];
337 let col = self.indices[1][k];
338 let val = self.values[k];
339 if let Some(entries) = other_row_map.get(&col) {
340 for &(other_col, other_val) in entries {
341 let entry = result_map.entry((row, other_col)).or_insert(T::zero());
342 *entry += val * other_val;
343 }
344 }
345 }
346
347 let mut row_indices = Vec::new();
349 let mut col_indices = Vec::new();
350 let mut values = Vec::new();
351 for (&(r, c), &v) in &result_map {
352 if v != T::zero() {
353 row_indices.push(r);
354 col_indices.push(c);
355 values.push(v);
356 }
357 }
358 let nnz = values.len();
359 Ok(SparseTensor {
360 indices: vec![row_indices, col_indices],
361 values,
362 shape: vec![m, n],
363 nnz,
364 })
365 }
366
367 pub fn sparse_dense_matmul(&self, dense: &Tensor<T>) -> Result<Tensor<T>> {
377 if self.ndim() != 2 {
378 return Err(CoreError::InvalidArgument {
379 reason: "sparse_dense_matmul requires a 2D sparse tensor",
380 });
381 }
382 if dense.ndim() != 2 {
383 return Err(CoreError::InvalidArgument {
384 reason: "sparse_dense_matmul requires a 2D dense tensor",
385 });
386 }
387 let m = self.shape[0];
388 let k_self = self.shape[1];
389 let k_dense = dense.shape()[0];
390 let n = dense.shape()[1];
391 if k_self != k_dense {
392 return Err(CoreError::DimensionMismatch {
393 expected: self.shape.clone(),
394 got: dense.shape().to_vec(),
395 });
396 }
397
398 let mut result = Tensor::zeros(vec![m, n]);
399 let result_data = result.as_mut_slice();
400 let dense_data = dense.as_slice();
401
402 for k in 0..self.nnz {
403 let row = self.indices[0][k];
404 let col = self.indices[1][k];
405 let val = self.values[k];
406 for j in 0..n {
407 let idx = row * n + j;
408 result_data[idx] += val * dense_data[col * n + j];
409 }
410 }
411
412 Ok(result)
413 }
414
415 pub fn transpose(&self) -> Result<SparseTensor<T>> {
421 if self.ndim() != 2 {
422 return Err(CoreError::InvalidArgument {
423 reason: "transpose requires a 2D sparse tensor",
424 });
425 }
426 Ok(SparseTensor {
427 indices: vec![self.indices[1].clone(), self.indices[0].clone()],
428 values: self.values.clone(),
429 shape: vec![self.shape[1], self.shape[0]],
430 nnz: self.nnz,
431 })
432 }
433
434 pub fn coalesce(&mut self) {
444 if self.nnz == 0 {
445 return;
446 }
447 let ndim = self.shape.len();
448
449 let mut map: HashMap<Vec<usize>, T> = HashMap::new();
451 for k in 0..self.nnz {
452 let coord: Vec<usize> = (0..ndim).map(|d| self.indices[d][k]).collect();
453 let entry = map.entry(coord).or_insert(T::zero());
454 *entry += self.values[k];
455 }
456
457 let mut entries: Vec<(Vec<usize>, T)> =
459 map.into_iter().filter(|(_, v)| *v != T::zero()).collect();
460 entries.sort_by(|(a, _), (b, _)| a.cmp(b));
461
462 let new_nnz = entries.len();
464 let mut new_indices: Vec<Vec<usize>> = vec![Vec::with_capacity(new_nnz); ndim];
465 let mut new_values = Vec::with_capacity(new_nnz);
466 for (coord, val) in entries {
467 for (dim, &c) in coord.iter().enumerate() {
468 new_indices[dim].push(c);
469 }
470 new_values.push(val);
471 }
472
473 self.indices = new_indices;
474 self.values = new_values;
475 self.nnz = new_nnz;
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_from_dense_roundtrip() {
        let dense = Tensor::from_vec(vec![0.0_f64, 1.0, 0.0, 2.0, 0.0, 3.0], vec![2, 3]).unwrap();
        let sparse = SparseTensor::from_dense(&dense);
        assert_eq!(sparse.nnz(), 3);
        let recovered = sparse.to_dense().unwrap();
        assert_eq!(dense, recovered);
    }

    #[test]
    fn test_sparse_matmul() {
        let a = SparseTensor::new(
            vec![vec![0, 0, 1], vec![0, 1, 1]],
            vec![1.0_f64, 2.0, 3.0],
            vec![2, 2],
        )
        .unwrap();
        let b = SparseTensor::new(
            vec![vec![0, 1, 1], vec![0, 0, 1]],
            vec![4.0_f64, 5.0, 6.0],
            vec![2, 2],
        )
        .unwrap();
        let mut c = a.sparse_matmul(&b).unwrap();
        c.coalesce();
        let dense_c = c.to_dense().unwrap();
        let expected = Tensor::from_vec(vec![14.0, 12.0, 15.0, 18.0], vec![2, 2]).unwrap();
        assert_eq!(dense_c, expected);
    }

    #[test]
    fn test_sparse_dense_matmul() {
        let a = SparseTensor::new(vec![vec![0, 1], vec![0, 1]], vec![1.0_f64, 2.0], vec![2, 2])
            .unwrap();
        let b = Tensor::from_vec(vec![3.0, 4.0, 5.0, 6.0], vec![2, 2]).unwrap();
        let result = a.sparse_dense_matmul(&b).unwrap();
        let expected = Tensor::from_vec(vec![3.0, 4.0, 10.0, 12.0], vec![2, 2]).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sparse_dense_matmul_rectangular() {
        // A is 2x3 with entries (0,2)=1 and (1,0)=2; B is 3x2.
        let a = SparseTensor::new(vec![vec![0, 1], vec![2, 0]], vec![1.0_f64, 2.0], vec![2, 3])
            .unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let result = a.sparse_dense_matmul(&b).unwrap();
        let expected = Tensor::from_vec(vec![5.0, 6.0, 2.0, 4.0], vec![2, 2]).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sparse_add() {
        let a = SparseTensor::new(vec![vec![0, 1], vec![0, 1]], vec![1.0_f64, 2.0], vec![2, 2])
            .unwrap();
        let b = SparseTensor::new(vec![vec![0, 1], vec![1, 0]], vec![3.0_f64, 4.0], vec![2, 2])
            .unwrap();
        let mut c = a.add(&b).unwrap();
        c.coalesce();
        let dense_c = c.to_dense().unwrap();
        let expected = Tensor::from_vec(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        assert_eq!(dense_c, expected);
    }

    #[test]
    fn test_scalar_mul() {
        let st = SparseTensor::new(vec![vec![0, 1], vec![0, 1]], vec![2.0_f64, 3.0], vec![2, 2])
            .unwrap();
        let scaled = st.scalar_mul(10.0);
        assert_eq!(scaled.values(), &[20.0, 30.0]);
        let dense = scaled.to_dense().unwrap();
        let expected = Tensor::from_vec(vec![20.0, 0.0, 0.0, 30.0], vec![2, 2]).unwrap();
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coalesce() {
        let mut st = SparseTensor::new(
            vec![vec![0, 0, 1], vec![0, 0, 1]],
            vec![1.0_f64, 2.0, 5.0],
            vec![2, 2],
        )
        .unwrap();
        assert_eq!(st.nnz(), 3);
        st.coalesce();
        assert_eq!(st.nnz(), 2);
        let dense = st.to_dense().unwrap();
        let expected = Tensor::from_vec(vec![3.0, 0.0, 0.0, 5.0], vec![2, 2]).unwrap();
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_transpose() {
        let st = SparseTensor::new(vec![vec![0], vec![1]], vec![3.0_f64], vec![2, 2]).unwrap();
        let tr = st.transpose().unwrap();
        assert_eq!(tr.shape(), &[2, 2]);
        assert_eq!(tr.indices()[0], vec![1]);
        assert_eq!(tr.indices()[1], vec![0]);
        let dense = tr.to_dense().unwrap();
        let expected = Tensor::from_vec(vec![0.0, 0.0, 3.0, 0.0], vec![2, 2]).unwrap();
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_transpose_rectangular() {
        // (1, 2) in a 2x3 tensor becomes (2, 1) in a 3x2 tensor: flat index 5.
        let st = SparseTensor::new(vec![vec![1], vec![2]], vec![7.0_f64], vec![2, 3]).unwrap();
        let tr = st.transpose().unwrap();
        assert_eq!(tr.shape(), &[3, 2]);
        let expected =
            Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0, 7.0], vec![3, 2]).unwrap();
        assert_eq!(tr.to_dense().unwrap(), expected);
    }

    #[test]
    fn test_new_rejects_invalid_input() {
        // One index vector per dimension is required.
        let wrong_rank = SparseTensor::new(vec![vec![0]], vec![1.0_f64], vec![2, 2]);
        assert!(matches!(wrong_rank, Err(CoreError::InvalidArgument { .. })));

        // Every index vector must be as long as `values`.
        let ragged =
            SparseTensor::new(vec![vec![0, 1], vec![0]], vec![1.0_f64, 2.0], vec![2, 2]);
        assert!(matches!(ragged, Err(CoreError::InvalidArgument { .. })));

        // Coordinates must lie inside the shape.
        let out_of_bounds = SparseTensor::new(vec![vec![0], vec![2]], vec![1.0_f64], vec![2, 2]);
        assert!(matches!(
            out_of_bounds,
            Err(CoreError::IndexOutOfBounds { .. })
        ));
    }

    #[test]
    fn test_zeros_and_density() {
        let empty = SparseTensor::<f64>::zeros(vec![3, 4]);
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.ndim(), 2);
        assert_eq!(empty.density(), 0.0);

        let st = SparseTensor::new(
            vec![vec![0, 1, 1], vec![0, 1, 2]],
            vec![1.0_f64, 2.0, 3.0],
            vec![2, 3],
        )
        .unwrap();
        assert_eq!(st.density(), 0.5);

        // A zero-sized dimension yields density 0 rather than a division by zero.
        let degenerate = SparseTensor::<f64>::zeros(vec![0, 5]);
        assert_eq!(degenerate.density(), 0.0);
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let a = SparseTensor::<f64>::zeros(vec![2, 3]);
        let b = SparseTensor::<f64>::zeros(vec![2, 3]);
        assert!(matches!(
            a.sparse_matmul(&b),
            Err(CoreError::DimensionMismatch { .. })
        ));
        let c = SparseTensor::<f64>::zeros(vec![3, 2]);
        assert!(matches!(a.add(&c), Err(CoreError::DimensionMismatch { .. })));
    }
}