1use serde::{Deserialize, Serialize};
35use thiserror::Error;
36
/// Errors produced by sparse tensor construction, validation, and conversion.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum SparseError {
    /// Conversion between the named source and target formats is not supported.
    #[error("Invalid sparse format conversion: {0} -> {1}")]
    InvalidConversion(String, String),

    /// An operand's dimensions do not match what the operation requires.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
    },

    /// An index lies outside the tensor's shape.
    #[error("Index out of bounds: {index:?} for shape {shape:?}")]
    IndexOutOfBounds {
        index: Vec<usize>,
        shape: Vec<usize>,
    },

    /// Internal structural invariants are violated; the message explains how.
    #[error("Invalid sparse tensor: {0}")]
    Invalid(String),

    /// The requested operation is not implemented for this input.
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),

    /// The tensor has no entries where at least one is required.
    #[error("Empty sparse tensor")]
    Empty,
}
64
/// Storage layout used by a sparse tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SparseFormat {
    /// Compressed Sparse Row: row pointers + column indices + values.
    CSR,
    /// Compressed Sparse Column: column pointers + row indices + values.
    CSC,
    /// Coordinate list: parallel (row, col, value) triples.
    COO,
}
75
76impl SparseFormat {
77 pub fn name(&self) -> &'static str {
79 match self {
80 SparseFormat::CSR => "CSR",
81 SparseFormat::CSC => "CSC",
82 SparseFormat::COO => "COO",
83 }
84 }
85
86 pub fn is_compressed(&self) -> bool {
88 matches!(self, SparseFormat::CSR | SparseFormat::CSC)
89 }
90}
91
/// Compressed Sparse Row matrix of `f64` values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCSR {
    /// (rows, cols) of the logical dense matrix.
    pub shape: (usize, usize),
    /// Length `rows + 1`; row `r` occupies `row_ptr[r]..row_ptr[r + 1]`
    /// in `col_indices`/`values`.
    pub row_ptr: Vec<usize>,
    /// Column index of each stored value, grouped by row.
    pub col_indices: Vec<usize>,
    /// Stored non-zero values, parallel to `col_indices`.
    pub values: Vec<f64>,
}
108
109impl SparseCSR {
110 pub fn new(rows: usize, cols: usize) -> Self {
112 Self {
113 shape: (rows, cols),
114 row_ptr: vec![0; rows + 1],
115 col_indices: Vec::new(),
116 values: Vec::new(),
117 }
118 }
119
120 pub fn nnz(&self) -> usize {
122 self.values.len()
123 }
124
125 pub fn sparsity_ratio(&self) -> f64 {
127 let total = self.shape.0 * self.shape.1;
128 1.0 - (self.nnz() as f64 / total as f64)
129 }
130
131 pub fn row(&self, row_idx: usize) -> Result<Vec<(usize, f64)>, SparseError> {
133 if row_idx >= self.shape.0 {
134 return Err(SparseError::IndexOutOfBounds {
135 index: vec![row_idx],
136 shape: vec![self.shape.0],
137 });
138 }
139
140 let start = self.row_ptr[row_idx];
141 let end = self.row_ptr[row_idx + 1];
142
143 Ok((start..end)
144 .map(|i| (self.col_indices[i], self.values[i]))
145 .collect())
146 }
147
148 pub fn multiply_dense(&self, vec: &[f64]) -> Result<Vec<f64>, SparseError> {
150 if vec.len() != self.shape.1 {
151 return Err(SparseError::ShapeMismatch {
152 expected: vec![self.shape.1],
153 actual: vec![vec.len()],
154 });
155 }
156
157 let mut result = vec![0.0; self.shape.0];
158
159 for row_idx in 0..self.shape.0 {
160 let start = self.row_ptr[row_idx];
161 let end = self.row_ptr[row_idx + 1];
162
163 let mut sum = 0.0;
164 for i in start..end {
165 sum += self.values[i] * vec[self.col_indices[i]];
166 }
167 result[row_idx] = sum;
168 }
169
170 Ok(result)
171 }
172
173 pub fn transpose(&self) -> SparseCSC {
175 let mut csc = SparseCSC::new(self.shape.1, self.shape.0);
176 csc.col_ptr = vec![0; self.shape.1 + 1];
177
178 let mut counts = vec![0; self.shape.1];
180 for &col in &self.col_indices {
181 counts[col] += 1;
182 }
183
184 let mut sum = 0;
186 for i in 0..self.shape.1 {
187 csc.col_ptr[i] = sum;
188 sum += counts[i];
189 }
190 csc.col_ptr[self.shape.1] = sum;
191
192 csc.row_indices = vec![0; self.nnz()];
194 csc.values = vec![0.0; self.nnz()];
195 let mut positions = csc.col_ptr[..self.shape.1].to_vec();
196
197 for row in 0..self.shape.0 {
198 let start = self.row_ptr[row];
199 let end = self.row_ptr[row + 1];
200
201 for i in start..end {
202 let col = self.col_indices[i];
203 let pos = positions[col];
204 csc.row_indices[pos] = row;
205 csc.values[pos] = self.values[i];
206 positions[col] += 1;
207 }
208 }
209
210 csc
211 }
212
213 pub fn memory_bytes(&self) -> usize {
215 self.row_ptr.len() * std::mem::size_of::<usize>()
216 + self.col_indices.len() * std::mem::size_of::<usize>()
217 + self.values.len() * std::mem::size_of::<f64>()
218 }
219
220 pub fn validate(&self) -> Result<(), SparseError> {
222 if self.row_ptr.len() != self.shape.0 + 1 {
224 return Err(SparseError::Invalid(format!(
225 "Invalid row_ptr length: expected {}, got {}",
226 self.shape.0 + 1,
227 self.row_ptr.len()
228 )));
229 }
230
231 for i in 0..self.shape.0 {
233 if self.row_ptr[i] > self.row_ptr[i + 1] {
234 return Err(SparseError::Invalid(format!(
235 "Non-monotonic row_ptr at index {}",
236 i
237 )));
238 }
239 }
240
241 if self.row_ptr[self.shape.0] != self.nnz() {
243 return Err(SparseError::Invalid(format!(
244 "Last row_ptr {} doesn't match nnz {}",
245 self.row_ptr[self.shape.0],
246 self.nnz()
247 )));
248 }
249
250 for &col in &self.col_indices {
252 if col >= self.shape.1 {
253 return Err(SparseError::IndexOutOfBounds {
254 index: vec![0, col],
255 shape: vec![self.shape.0, self.shape.1],
256 });
257 }
258 }
259
260 Ok(())
261 }
262}
263
/// Compressed Sparse Column matrix of `f64` values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCSC {
    /// (rows, cols) of the logical dense matrix.
    pub shape: (usize, usize),
    /// Length `cols + 1`; column `c` occupies `col_ptr[c]..col_ptr[c + 1]`
    /// in `row_indices`/`values`.
    pub col_ptr: Vec<usize>,
    /// Row index of each stored value, grouped by column.
    pub row_indices: Vec<usize>,
    /// Stored non-zero values, parallel to `row_indices`.
    pub values: Vec<f64>,
}
280
281impl SparseCSC {
282 pub fn new(rows: usize, cols: usize) -> Self {
284 Self {
285 shape: (rows, cols),
286 col_ptr: vec![0; cols + 1],
287 row_indices: Vec::new(),
288 values: Vec::new(),
289 }
290 }
291
292 pub fn nnz(&self) -> usize {
294 self.values.len()
295 }
296
297 pub fn sparsity_ratio(&self) -> f64 {
299 let total = self.shape.0 * self.shape.1;
300 1.0 - (self.nnz() as f64 / total as f64)
301 }
302
303 pub fn column(&self, col_idx: usize) -> Result<Vec<(usize, f64)>, SparseError> {
305 if col_idx >= self.shape.1 {
306 return Err(SparseError::IndexOutOfBounds {
307 index: vec![col_idx],
308 shape: vec![self.shape.1],
309 });
310 }
311
312 let start = self.col_ptr[col_idx];
313 let end = self.col_ptr[col_idx + 1];
314
315 Ok((start..end)
316 .map(|i| (self.row_indices[i], self.values[i]))
317 .collect())
318 }
319
320 pub fn transpose(&self) -> SparseCSR {
322 let mut csr = SparseCSR::new(self.shape.1, self.shape.0);
323 csr.row_ptr = self.col_ptr.clone();
324 csr.col_indices = self.row_indices.clone();
325 csr.values = self.values.clone();
326 csr
327 }
328}
329
/// Coordinate-list (triplet) matrix of `f64` values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCOO {
    /// (rows, cols) of the logical dense matrix.
    pub shape: (usize, usize),
    /// Row coordinate of each entry; parallel to `col_indices` and `values`.
    pub row_indices: Vec<usize>,
    /// Column coordinate of each entry; parallel to `row_indices` and `values`.
    pub col_indices: Vec<usize>,
    /// Entry values, in insertion order (duplicate coordinates are allowed).
    pub values: Vec<f64>,
}
346
347impl SparseCOO {
348 pub fn new(rows: usize, cols: usize) -> Self {
350 Self {
351 shape: (rows, cols),
352 row_indices: Vec::new(),
353 col_indices: Vec::new(),
354 values: Vec::new(),
355 }
356 }
357
358 pub fn add_entry(&mut self, row: usize, col: usize, value: f64) -> Result<(), SparseError> {
360 if row >= self.shape.0 || col >= self.shape.1 {
361 return Err(SparseError::IndexOutOfBounds {
362 index: vec![row, col],
363 shape: vec![self.shape.0, self.shape.1],
364 });
365 }
366
367 self.row_indices.push(row);
368 self.col_indices.push(col);
369 self.values.push(value);
370
371 Ok(())
372 }
373
374 pub fn nnz(&self) -> usize {
376 self.values.len()
377 }
378
379 pub fn sparsity_ratio(&self) -> f64 {
381 let total = self.shape.0 * self.shape.1;
382 1.0 - (self.nnz() as f64 / total as f64)
383 }
384
385 pub fn to_csr(&self) -> SparseCSR {
387 let mut csr = SparseCSR::new(self.shape.0, self.shape.1);
388
389 let mut entries: Vec<_> = (0..self.nnz())
391 .map(|i| (self.row_indices[i], self.col_indices[i], self.values[i]))
392 .collect();
393 entries.sort_by_key(|(r, c, _)| (*r, *c));
394
395 csr.row_ptr = vec![0; self.shape.0 + 1];
397 csr.col_indices = Vec::with_capacity(entries.len());
398 csr.values = Vec::with_capacity(entries.len());
399
400 let mut current_row = 0;
401 for (row, col, val) in entries {
402 while current_row < row {
403 current_row += 1;
404 csr.row_ptr[current_row] = csr.col_indices.len();
405 }
406 csr.col_indices.push(col);
407 csr.values.push(val);
408 }
409
410 for i in current_row + 1..=self.shape.0 {
412 csr.row_ptr[i] = csr.col_indices.len();
413 }
414
415 csr
416 }
417
418 pub fn to_csc(&self) -> SparseCSC {
420 let mut csc = SparseCSC::new(self.shape.0, self.shape.1);
421
422 let mut entries: Vec<_> = (0..self.nnz())
424 .map(|i| (self.col_indices[i], self.row_indices[i], self.values[i]))
425 .collect();
426 entries.sort_by_key(|(c, r, _)| (*c, *r));
427
428 csc.col_ptr = vec![0; self.shape.1 + 1];
430 csc.row_indices = Vec::with_capacity(entries.len());
431 csc.values = Vec::with_capacity(entries.len());
432
433 let mut current_col = 0;
434 for (col, row, val) in entries {
435 while current_col < col {
436 current_col += 1;
437 csc.col_ptr[current_col] = csc.row_indices.len();
438 }
439 csc.row_indices.push(row);
440 csc.values.push(val);
441 }
442
443 for i in current_col + 1..=self.shape.1 {
445 csc.col_ptr[i] = csc.row_indices.len();
446 }
447
448 csc
449 }
450}
451
/// A 2D sparse matrix in one of the supported storage formats.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SparseTensor {
    /// Compressed Sparse Row storage.
    CSR(SparseCSR),
    /// Compressed Sparse Column storage.
    CSC(SparseCSC),
    /// Coordinate-list storage.
    COO(SparseCOO),
}
462
463impl SparseTensor {
464 pub fn builder(shape: Vec<usize>, format: SparseFormat) -> SparseTensorBuilder {
466 SparseTensorBuilder::new(shape, format)
467 }
468
469 pub fn format(&self) -> SparseFormat {
471 match self {
472 SparseTensor::CSR(_) => SparseFormat::CSR,
473 SparseTensor::CSC(_) => SparseFormat::CSC,
474 SparseTensor::COO(_) => SparseFormat::COO,
475 }
476 }
477
478 pub fn shape(&self) -> Vec<usize> {
480 match self {
481 SparseTensor::CSR(m) => vec![m.shape.0, m.shape.1],
482 SparseTensor::CSC(m) => vec![m.shape.0, m.shape.1],
483 SparseTensor::COO(m) => vec![m.shape.0, m.shape.1],
484 }
485 }
486
487 pub fn nnz(&self) -> usize {
489 match self {
490 SparseTensor::CSR(m) => m.nnz(),
491 SparseTensor::CSC(m) => m.nnz(),
492 SparseTensor::COO(m) => m.nnz(),
493 }
494 }
495
496 pub fn sparsity_ratio(&self) -> f64 {
498 match self {
499 SparseTensor::CSR(m) => m.sparsity_ratio(),
500 SparseTensor::CSC(m) => m.sparsity_ratio(),
501 SparseTensor::COO(m) => m.sparsity_ratio(),
502 }
503 }
504
505 pub fn to_csr(&self) -> Result<SparseTensor, SparseError> {
507 match self {
508 SparseTensor::CSR(_) => Ok(self.clone()),
509 SparseTensor::CSC(m) => Ok(SparseTensor::CSR(m.transpose())),
510 SparseTensor::COO(m) => Ok(SparseTensor::CSR(m.to_csr())),
511 }
512 }
513
514 pub fn to_csc(&self) -> Result<SparseTensor, SparseError> {
516 match self {
517 SparseTensor::CSR(m) => Ok(SparseTensor::CSC(m.transpose())),
518 SparseTensor::CSC(_) => Ok(self.clone()),
519 SparseTensor::COO(m) => Ok(SparseTensor::CSC(m.to_csc())),
520 }
521 }
522
523 pub fn to_coo(&self) -> Result<SparseTensor, SparseError> {
525 match self {
526 SparseTensor::COO(_) => Ok(self.clone()),
527 SparseTensor::CSR(m) => {
528 let mut coo = SparseCOO::new(m.shape.0, m.shape.1);
529 for row in 0..m.shape.0 {
530 let start = m.row_ptr[row];
531 let end = m.row_ptr[row + 1];
532 for i in start..end {
533 coo.add_entry(row, m.col_indices[i], m.values[i])?;
534 }
535 }
536 Ok(SparseTensor::COO(coo))
537 }
538 SparseTensor::CSC(m) => {
539 let mut coo = SparseCOO::new(m.shape.0, m.shape.1);
540 for col in 0..m.shape.1 {
541 let start = m.col_ptr[col];
542 let end = m.col_ptr[col + 1];
543 for i in start..end {
544 coo.add_entry(m.row_indices[i], col, m.values[i])?;
545 }
546 }
547 Ok(SparseTensor::COO(coo))
548 }
549 }
550 }
551
552 pub fn memory_bytes(&self) -> usize {
554 match self {
555 SparseTensor::CSR(m) => m.memory_bytes(),
556 SparseTensor::CSC(m) => {
557 m.col_ptr.len() * std::mem::size_of::<usize>()
558 + m.row_indices.len() * std::mem::size_of::<usize>()
559 + m.values.len() * std::mem::size_of::<f64>()
560 }
561 SparseTensor::COO(m) => {
562 (m.row_indices.len() + m.col_indices.len()) * std::mem::size_of::<usize>()
563 + m.values.len() * std::mem::size_of::<f64>()
564 }
565 }
566 }
567}
568
/// Incrementally collects `(indices, value)` entries, then materializes a
/// `SparseTensor` in the requested format via `build()`.
pub struct SparseTensorBuilder {
    // Target tensor shape; every entry is validated against it.
    shape: Vec<usize>,
    // Format the final tensor is materialized in.
    format: SparseFormat,
    // Raw entries in insertion order; converted on `build()`.
    entries: Vec<(Vec<usize>, f64)>,
}
575
576impl SparseTensorBuilder {
577 pub fn new(shape: Vec<usize>, format: SparseFormat) -> Self {
579 Self {
580 shape,
581 format,
582 entries: Vec::new(),
583 }
584 }
585
586 pub fn add_entry(&mut self, indices: Vec<usize>, value: f64) -> Result<(), SparseError> {
588 if indices.len() != self.shape.len() {
589 return Err(SparseError::ShapeMismatch {
590 expected: vec![self.shape.len()],
591 actual: vec![indices.len()],
592 });
593 }
594
595 for (i, &idx) in indices.iter().enumerate() {
596 if idx >= self.shape[i] {
597 return Err(SparseError::IndexOutOfBounds {
598 index: indices.clone(),
599 shape: self.shape.clone(),
600 });
601 }
602 }
603
604 self.entries.push((indices, value));
605 Ok(())
606 }
607
608 pub fn build(self) -> Result<SparseTensor, SparseError> {
610 if self.shape.len() != 2 {
612 return Err(SparseError::UnsupportedOperation(format!(
613 "Only 2D sparse tensors are supported, got shape {:?}",
614 self.shape
615 )));
616 }
617
618 let rows = self.shape[0];
619 let cols = self.shape[1];
620
621 let mut coo = SparseCOO::new(rows, cols);
623 for (indices, value) in self.entries {
624 coo.add_entry(indices[0], indices[1], value)?;
625 }
626
627 match self.format {
629 SparseFormat::COO => Ok(SparseTensor::COO(coo)),
630 SparseFormat::CSR => Ok(SparseTensor::CSR(coo.to_csr())),
631 SparseFormat::CSC => Ok(SparseTensor::CSC(coo.to_csc())),
632 }
633 }
634}
635
/// Counts near-zero elements of `data` and the resulting sparsity ratio.
///
/// An element counts as zero when `|x| < threshold`. Returns
/// `(zero_count, zero_count / data.len())`. An empty slice yields `(0, 0.0)`
/// instead of a NaN ratio from the 0/0 division.
pub fn detect_sparsity(data: &[f64], threshold: f64) -> (usize, f64) {
    if data.is_empty() {
        return (0, 0.0);
    }
    let zeros = data.iter().filter(|&&x| x.abs() < threshold).count();
    let sparsity = zeros as f64 / data.len() as f64;
    (zeros, sparsity)
}
643
644pub fn to_sparse_if_beneficial(
646 data: &[f64],
647 shape: Vec<usize>,
648 threshold: f64,
649 min_sparsity: f64,
650) -> Result<Option<SparseTensor>, SparseError> {
651 let (_, sparsity) = detect_sparsity(data, threshold);
652
653 if sparsity < min_sparsity {
654 return Ok(None);
655 }
656
657 let mut builder = SparseTensor::builder(shape.clone(), SparseFormat::CSR);
659
660 if shape.len() == 2 {
661 let cols = shape[1];
662 for (i, &val) in data.iter().enumerate() {
663 if val.abs() >= threshold {
664 let row = i / cols;
665 let col = i % cols;
666 builder.add_entry(vec![row, col], val)?;
667 }
668 }
669 }
670
671 Ok(Some(builder.build()?))
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    // Format metadata: names and compressed/uncompressed classification.
    #[test]
    fn test_sparse_format() {
        assert_eq!(SparseFormat::CSR.name(), "CSR");
        assert!(SparseFormat::CSR.is_compressed());
        assert!(!SparseFormat::COO.is_compressed());
    }

    // COO starts empty and counts entries as they are added.
    #[test]
    fn test_sparse_coo_creation() {
        let mut coo = SparseCOO::new(3, 3);
        assert_eq!(coo.shape, (3, 3));
        assert_eq!(coo.nnz(), 0);

        coo.add_entry(0, 1, 5.0).unwrap();
        coo.add_entry(1, 2, 3.0).unwrap();
        assert_eq!(coo.nnz(), 2);
    }

    // COO -> CSR keeps shape and nnz, and yields a structurally valid matrix.
    #[test]
    fn test_sparse_coo_to_csr() {
        let mut coo = SparseCOO::new(3, 3);
        coo.add_entry(0, 0, 1.0).unwrap();
        coo.add_entry(0, 2, 2.0).unwrap();
        coo.add_entry(2, 1, 3.0).unwrap();

        let csr = coo.to_csr();
        assert_eq!(csr.shape, (3, 3));
        assert_eq!(csr.nnz(), 3);
        assert!(csr.validate().is_ok());
    }

    // CSR * dense vector: [[1,0,2],[0,3,0]] * [1,2,3] = [7, 6].
    #[test]
    fn test_sparse_csr_multiply_dense() {
        let mut coo = SparseCOO::new(2, 3);
        coo.add_entry(0, 0, 1.0).unwrap();
        coo.add_entry(0, 2, 2.0).unwrap();
        coo.add_entry(1, 1, 3.0).unwrap();

        let csr = coo.to_csr();
        let vec = vec![1.0, 2.0, 3.0];
        let result = csr.multiply_dense(&vec).unwrap();

        assert_eq!(result.len(), 2);
        assert!((result[0] - 7.0).abs() < 1e-10); // row 0: 1*1 + 2*3 = 7
        assert!((result[1] - 6.0).abs() < 1e-10); // row 1: 3*2 = 6
    }

    // Row extraction returns (col, value) pairs in ascending column order.
    #[test]
    fn test_sparse_csr_row_access() {
        let mut coo = SparseCOO::new(3, 3);
        coo.add_entry(0, 0, 1.0).unwrap();
        coo.add_entry(0, 2, 2.0).unwrap();
        coo.add_entry(1, 1, 3.0).unwrap();

        let csr = coo.to_csr();
        let row0 = csr.row(0).unwrap();
        assert_eq!(row0.len(), 2);
        assert_eq!(row0[0], (0, 1.0));
        assert_eq!(row0[1], (2, 2.0));

        let row1 = csr.row(1).unwrap();
        assert_eq!(row1.len(), 1);
        assert_eq!(row1[0], (1, 3.0));
    }

    // Transposing a 2x3 CSR yields a 3x2 CSC with the same entry count.
    #[test]
    fn test_sparse_csr_transpose() {
        let mut coo = SparseCOO::new(2, 3);
        coo.add_entry(0, 0, 1.0).unwrap();
        coo.add_entry(0, 2, 2.0).unwrap();
        coo.add_entry(1, 1, 3.0).unwrap();

        let csr = coo.to_csr();
        let csc = csr.transpose();

        assert_eq!(csc.shape, (3, 2));
        assert_eq!(csc.nnz(), 3);
    }

    // 2 of 100 entries stored -> ~98% sparsity.
    #[test]
    fn test_sparsity_ratio() {
        let mut coo = SparseCOO::new(10, 10);
        coo.add_entry(0, 0, 1.0).unwrap();
        coo.add_entry(5, 5, 2.0).unwrap();

        let sparsity = coo.sparsity_ratio();
        assert!((sparsity - 0.98).abs() < 0.01);
    }

    // Builder produces the requested target format with all entries intact.
    #[test]
    fn test_sparse_tensor_builder() {
        let mut builder = SparseTensor::builder(vec![3, 3], SparseFormat::CSR);
        builder.add_entry(vec![0, 0], 1.0).unwrap();
        builder.add_entry(vec![1, 2], 2.0).unwrap();

        let sparse = builder.build().unwrap();
        assert_eq!(sparse.format(), SparseFormat::CSR);
        assert_eq!(sparse.nnz(), 2);
    }

    // Format conversions preserve the number of stored entries.
    #[test]
    fn test_sparse_tensor_conversion() {
        let mut builder = SparseTensor::builder(vec![3, 3], SparseFormat::COO);
        builder.add_entry(vec![0, 0], 1.0).unwrap();
        builder.add_entry(vec![1, 2], 2.0).unwrap();

        let coo = builder.build().unwrap();
        let csr = coo.to_csr().unwrap();
        let csc = csr.to_csc().unwrap();

        assert_eq!(coo.nnz(), 2);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csc.nnz(), 2);
    }

    // 6 of 9 elements are below threshold -> sparsity ~2/3.
    #[test]
    fn test_detect_sparsity() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
        let (zeros, sparsity) = detect_sparsity(&data, 1e-10);

        assert_eq!(zeros, 6);
        assert!((sparsity - 0.666).abs() < 0.01);
    }

    // A mostly-zero buffer above the min-sparsity cutoff converts to sparse.
    #[test]
    fn test_to_sparse_if_beneficial() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0];
        let shape = vec![2, 3];

        let sparse = to_sparse_if_beneficial(&data, shape, 1e-10, 0.5).unwrap();
        assert!(sparse.is_some());

        let sparse = sparse.unwrap();
        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.sparsity_ratio() > 0.5);
    }

    // A hand-built CSR with consistent pointers passes validation.
    #[test]
    fn test_sparse_csr_validation() {
        let csr = SparseCSR {
            shape: (3, 3),
            row_ptr: vec![0, 2, 3, 3],
            col_indices: vec![0, 2, 1],
            values: vec![1.0, 2.0, 3.0],
        };

        assert!(csr.validate().is_ok());
    }

    // Two entries in a 100x100 matrix should use far less memory than dense.
    #[test]
    fn test_sparse_memory_usage() {
        let mut builder = SparseTensor::builder(vec![100, 100], SparseFormat::CSR);
        builder.add_entry(vec![0, 0], 1.0).unwrap();
        builder.add_entry(vec![50, 50], 2.0).unwrap();

        let sparse = builder.build().unwrap();
        let memory = sparse.memory_bytes();

        let dense_memory = 100 * 100 * std::mem::size_of::<f64>();
        assert!(memory < dense_memory / 10);
    }
}