1use crate::{BackendResult, Device};
14use std::collections::HashMap;
15use torsh_core::error::TorshError;
16
17#[cfg(not(feature = "std"))]
18use alloc::{string::String, vec::Vec};
19
/// Storage layout of a [`SparseMatrix`]; dictates how its index vectors are read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate list: parallel (row, col, value) triplets.
    Coo,
    /// Compressed sparse row: row-pointer array plus per-entry column indices.
    Csr,
    /// Compressed sparse column: column-pointer array plus per-entry row indices.
    Csc,
    /// Block sparse row: dense blocks addressed CSR-style.
    Bsr,
    /// Plain dense storage (no compression).
    Dense,
}
34
/// Sparse matrix container whose index vectors are interpreted according to `format`:
///
/// * COO — `row_indices`/`col_indices` hold one (row, col) pair per value.
/// * CSR — `row_indices` is a `rows + 1` pointer array; `col_indices` is per-entry.
/// * CSC — `col_indices` is a `cols + 1` pointer array; `row_indices` is per-entry.
/// * BSR — `row_indices` holds block-row pointers, `values` stores dense blocks,
///   and `nnz` counts blocks (see `to_bsr`).
#[derive(Debug, Clone)]
pub struct SparseMatrix<T> {
    /// Active storage layout.
    pub format: SparseFormat,
    /// Logical number of rows.
    pub rows: usize,
    /// Logical number of columns.
    pub cols: usize,
    /// Number of explicitly stored entries (stored blocks, for BSR).
    pub nnz: usize,
    /// Stored values.
    pub values: Vec<T>,
    /// Row indices or row pointers, depending on `format`.
    pub row_indices: Vec<usize>,
    /// Column indices or column pointers, depending on `format`.
    pub col_indices: Vec<usize>,
    /// Block dimensions; `Some` only for BSR matrices.
    pub block_size: Option<(usize, usize)>,
}
55
56impl<T> Default for SparseMatrix<T> {
57 fn default() -> Self {
58 Self {
59 format: SparseFormat::Coo,
60 rows: 0,
61 cols: 0,
62 nnz: 0,
63 values: Vec::new(),
64 row_indices: Vec::new(),
65 col_indices: Vec::new(),
66 block_size: None,
67 }
68 }
69}
70
71impl<T: Clone + Default + PartialEq> SparseMatrix<T> {
72 pub fn new_coo(rows: usize, cols: usize) -> Self {
74 Self {
75 format: SparseFormat::Coo,
76 rows,
77 cols,
78 nnz: 0,
79 values: Vec::new(),
80 row_indices: Vec::new(),
81 col_indices: Vec::new(),
82 block_size: None,
83 }
84 }
85
86 pub fn new_csr(rows: usize, cols: usize) -> Self {
88 Self {
89 format: SparseFormat::Csr,
90 rows,
91 cols,
92 nnz: 0,
93 values: Vec::new(),
94 row_indices: Vec::with_capacity(rows + 1), col_indices: Vec::new(),
96 block_size: None,
97 }
98 }
99
100 pub fn new_csc(rows: usize, cols: usize) -> Self {
102 Self {
103 format: SparseFormat::Csc,
104 rows,
105 cols,
106 nnz: 0,
107 values: Vec::new(),
108 row_indices: Vec::new(),
109 col_indices: Vec::with_capacity(cols + 1), block_size: None,
111 }
112 }
113
114 pub fn insert_coo(&mut self, row: usize, col: usize, value: T) -> BackendResult<()> {
116 if self.format != SparseFormat::Coo {
117 return Err(TorshError::ComputeError(
118 "Matrix is not in COO format".to_string(),
119 ));
120 }
121
122 if row >= self.rows || col >= self.cols {
123 return Err(TorshError::ComputeError("Index out of bounds".to_string()));
124 }
125
126 self.row_indices.push(row);
128 self.col_indices.push(col);
129 self.values.push(value);
130 self.nnz += 1;
131
132 Ok(())
133 }
134
135 pub fn to_csr(&self) -> BackendResult<SparseMatrix<T>> {
137 if self.format != SparseFormat::Coo {
138 return Err(TorshError::ComputeError(
139 "Source matrix must be in COO format".to_string(),
140 ));
141 }
142
143 let mut csr = SparseMatrix::new_csr(self.rows, self.cols);
144 csr.nnz = self.nnz;
145
146 if self.nnz == 0 {
147 csr.row_indices = vec![0; self.rows + 1];
149 return Ok(csr);
150 }
151
152 let mut row_counts = vec![0; self.rows];
154 for &row in &self.row_indices {
155 row_counts[row] += 1;
156 }
157
158 csr.row_indices.push(0);
160 for count in row_counts {
161 let last = *csr
162 .row_indices
163 .last()
164 .expect("row_indices should not be empty after initial push");
165 csr.row_indices.push(last + count);
166 }
167
168 let mut triplets: Vec<(usize, usize, T)> = self
170 .row_indices
171 .iter()
172 .zip(self.col_indices.iter())
173 .zip(self.values.iter())
174 .map(|((&r, &c), v)| (r, c, v.clone()))
175 .collect();
176
177 triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
178
179 csr.values.reserve(self.nnz);
181 csr.col_indices.reserve(self.nnz);
182
183 for (_, col, value) in triplets {
184 csr.col_indices.push(col);
185 csr.values.push(value);
186 }
187
188 Ok(csr)
189 }
190
191 pub fn to_csc(&self) -> BackendResult<SparseMatrix<T>> {
193 if self.format != SparseFormat::Coo {
194 return Err(TorshError::ComputeError(
195 "Source matrix must be in COO format".to_string(),
196 ));
197 }
198
199 let mut csc = SparseMatrix::new_csc(self.rows, self.cols);
200 csc.nnz = self.nnz;
201
202 if self.nnz == 0 {
203 csc.col_indices = vec![0; self.cols + 1];
205 return Ok(csc);
206 }
207
208 let mut col_counts = vec![0; self.cols];
210 for &col in &self.col_indices {
211 col_counts[col] += 1;
212 }
213
214 csc.col_indices.push(0);
216 for count in col_counts {
217 let last = *csc
218 .col_indices
219 .last()
220 .expect("col_indices should not be empty after initial push");
221 csc.col_indices.push(last + count);
222 }
223
224 let mut triplets: Vec<(usize, usize, T)> = self
226 .row_indices
227 .iter()
228 .zip(self.col_indices.iter())
229 .zip(self.values.iter())
230 .map(|((&r, &c), v)| (r, c, v.clone()))
231 .collect();
232
233 triplets.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
234
235 csc.values.reserve(self.nnz);
237 csc.row_indices.reserve(self.nnz);
238
239 for (row, _, value) in triplets {
240 csc.row_indices.push(row);
241 csc.values.push(value);
242 }
243
244 Ok(csc)
245 }
246
247 pub fn sparsity_ratio(&self) -> f64 {
249 if self.rows == 0 || self.cols == 0 {
250 return 0.0;
251 }
252 self.nnz as f64 / (self.rows * self.cols) as f64
253 }
254
255 pub fn is_sparse(&self) -> bool {
257 self.sparsity_ratio() < 0.5
258 }
259}
260
/// Core sparse linear-algebra operations a backend must provide for element
/// type `T`.
pub trait SparseOps<T> {
    /// Sparse matrix–vector product: writes `matrix * x` into `y`.
    fn spmv(&self, matrix: &SparseMatrix<T>, x: &[T], y: &mut [T]) -> BackendResult<()>;

    /// Sparse matrix–matrix product `a * b`.
    fn spmm(&self, a: &SparseMatrix<T>, b: &SparseMatrix<T>) -> BackendResult<SparseMatrix<T>>;

    /// Element-wise sum `a + b`.
    fn sparse_add(
        &self,
        a: &SparseMatrix<T>,
        b: &SparseMatrix<T>,
    ) -> BackendResult<SparseMatrix<T>>;

    /// Expands `matrix` into a row-major dense buffer of length `rows * cols`.
    fn to_dense(&self, matrix: &SparseMatrix<T>) -> BackendResult<Vec<T>>;

    /// Builds a sparse matrix from a row-major dense buffer, keeping entries
    /// whose magnitude exceeds `threshold`.
    fn from_dense(
        &self,
        dense: &[T],
        rows: usize,
        cols: usize,
        threshold: T,
    ) -> BackendResult<SparseMatrix<T>>;

    /// Returns the transpose of `matrix`.
    fn transpose(&self, matrix: &SparseMatrix<T>) -> BackendResult<SparseMatrix<T>>;
}
291
/// Reference (sequential CPU) implementation of [`SparseOps`] for `f32`.
#[derive(Debug)]
pub struct DefaultSparseOps {
    /// Device this backend was created for (not consulted by the kernels here).
    #[allow(dead_code)]
    device: Device,
    /// Tuning hints, set via `with_hints`; not read by the kernels in this module.
    optimization_hints: SparseOptimizationHints,
}
301
302impl DefaultSparseOps {
303 pub fn new(device: Device) -> Self {
305 Self {
306 device,
307 optimization_hints: SparseOptimizationHints::default(),
308 }
309 }
310
311 pub fn with_hints(mut self, hints: SparseOptimizationHints) -> Self {
313 self.optimization_hints = hints;
314 self
315 }
316}
317
318impl SparseOps<f32> for DefaultSparseOps {
319 fn spmv(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
320 if x.len() != matrix.cols || y.len() != matrix.rows {
321 return Err(TorshError::ComputeError("Dimension mismatch".to_string()));
322 }
323
324 y.fill(0.0);
326
327 match matrix.format {
328 SparseFormat::Csr => self.spmv_csr(matrix, x, y),
329 SparseFormat::Coo => self.spmv_coo(matrix, x, y),
330 SparseFormat::Csc => self.spmv_csc(matrix, x, y),
331 _ => Err(TorshError::ComputeError(
332 "Unsupported sparse format for SpMV".to_string(),
333 )),
334 }
335 }
336
337 fn spmm(
338 &self,
339 a: &SparseMatrix<f32>,
340 b: &SparseMatrix<f32>,
341 ) -> BackendResult<SparseMatrix<f32>> {
342 if a.cols != b.rows {
343 return Err(TorshError::ComputeError(
344 "Matrix dimensions incompatible for multiplication".to_string(),
345 ));
346 }
347
348 let a_csr = if a.format == SparseFormat::Csr {
350 a.clone()
351 } else {
352 a.to_csr()?
353 };
354
355 let b_csr = if b.format == SparseFormat::Csr {
356 b.clone()
357 } else {
358 b.to_csr()?
359 };
360
361 self.spmm_csr_csr(&a_csr, &b_csr)
362 }
363
364 fn sparse_add(
365 &self,
366 a: &SparseMatrix<f32>,
367 b: &SparseMatrix<f32>,
368 ) -> BackendResult<SparseMatrix<f32>> {
369 if a.rows != b.rows || a.cols != b.cols {
370 return Err(TorshError::ComputeError(
371 "Matrix dimensions must match for addition".to_string(),
372 ));
373 }
374
375 let a_coo = if a.format == SparseFormat::Coo {
377 a.clone()
378 } else {
379 return Err(TorshError::ComputeError(
381 "Sparse addition requires COO format".to_string(),
382 ));
383 };
384
385 let b_coo = if b.format == SparseFormat::Coo {
386 b.clone()
387 } else {
388 return Err(TorshError::ComputeError(
389 "Sparse addition requires COO format".to_string(),
390 ));
391 };
392
393 self.sparse_add_coo(&a_coo, &b_coo)
394 }
395
396 fn to_dense(&self, matrix: &SparseMatrix<f32>) -> BackendResult<Vec<f32>> {
397 let mut dense = vec![0.0; matrix.rows * matrix.cols];
398
399 match matrix.format {
400 SparseFormat::Coo => {
401 for i in 0..matrix.nnz {
402 let row = matrix.row_indices[i];
403 let col = matrix.col_indices[i];
404 let val = matrix.values[i];
405 dense[row * matrix.cols + col] = val;
406 }
407 }
408 SparseFormat::Csr => {
409 for row in 0..matrix.rows {
410 let start = matrix.row_indices[row];
411 let end = matrix.row_indices[row + 1];
412 for idx in start..end {
413 let col = matrix.col_indices[idx];
414 let val = matrix.values[idx];
415 dense[row * matrix.cols + col] = val;
416 }
417 }
418 }
419 SparseFormat::Csc => {
420 for col in 0..matrix.cols {
421 let start = matrix.col_indices[col];
422 let end = matrix.col_indices[col + 1];
423 for idx in start..end {
424 let row = matrix.row_indices[idx];
425 let val = matrix.values[idx];
426 dense[row * matrix.cols + col] = val;
427 }
428 }
429 }
430 _ => {
431 return Err(TorshError::ComputeError(
432 "Unsupported format for dense conversion".to_string(),
433 ))
434 }
435 }
436
437 Ok(dense)
438 }
439
440 fn from_dense(
441 &self,
442 dense: &[f32],
443 rows: usize,
444 cols: usize,
445 threshold: f32,
446 ) -> BackendResult<SparseMatrix<f32>> {
447 if dense.len() != rows * cols {
448 return Err(TorshError::ComputeError(
449 "Dense array size doesn't match dimensions".to_string(),
450 ));
451 }
452
453 let mut sparse = SparseMatrix::new_coo(rows, cols);
454
455 for row in 0..rows {
456 for col in 0..cols {
457 let val = dense[row * cols + col];
458 if val.abs() > threshold {
459 sparse.insert_coo(row, col, val)?;
460 }
461 }
462 }
463
464 Ok(sparse)
465 }
466
467 fn transpose(&self, matrix: &SparseMatrix<f32>) -> BackendResult<SparseMatrix<f32>> {
468 match matrix.format {
469 SparseFormat::Coo => {
470 let mut transposed = SparseMatrix::new_coo(matrix.cols, matrix.rows);
471 transposed.nnz = matrix.nnz;
472
473 transposed.row_indices = matrix.col_indices.clone();
475 transposed.col_indices = matrix.row_indices.clone();
476 transposed.values = matrix.values.clone();
477
478 Ok(transposed)
479 }
480 SparseFormat::Csr => {
481 let mut transposed = SparseMatrix::new_csc(matrix.cols, matrix.rows);
483 transposed.nnz = matrix.nnz;
484 transposed.values = matrix.values.clone();
485 transposed.row_indices = matrix.col_indices.clone();
486 transposed.col_indices = matrix.row_indices.clone();
487 Ok(transposed)
488 }
489 SparseFormat::Csc => {
490 let mut transposed = SparseMatrix::new_csr(matrix.cols, matrix.rows);
492 transposed.nnz = matrix.nnz;
493 transposed.values = matrix.values.clone();
494 transposed.row_indices = matrix.col_indices.clone();
495 transposed.col_indices = matrix.row_indices.clone();
496 Ok(transposed)
497 }
498 _ => Err(TorshError::ComputeError(
499 "Unsupported format for transpose".to_string(),
500 )),
501 }
502 }
503}
504
505impl DefaultSparseOps {
506 fn spmv_csr(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
508 for row in 0..matrix.rows {
509 let start = matrix.row_indices[row];
510 let end = matrix.row_indices[row + 1];
511 let mut sum = 0.0;
512
513 for idx in start..end {
514 let col = matrix.col_indices[idx];
515 let val = matrix.values[idx];
516 sum += val * x[col];
517 }
518
519 y[row] = sum;
520 }
521 Ok(())
522 }
523
524 fn spmv_coo(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
526 for i in 0..matrix.nnz {
527 let row = matrix.row_indices[i];
528 let col = matrix.col_indices[i];
529 let val = matrix.values[i];
530 y[row] += val * x[col];
531 }
532 Ok(())
533 }
534
535 fn spmv_csc(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
537 for col in 0..matrix.cols {
538 let start = matrix.col_indices[col];
539 let end = matrix.col_indices[col + 1];
540 let x_val = x[col];
541
542 for idx in start..end {
543 let row = matrix.row_indices[idx];
544 let val = matrix.values[idx];
545 y[row] += val * x_val;
546 }
547 }
548 Ok(())
549 }
550
551 fn spmm_csr_csr(
553 &self,
554 a: &SparseMatrix<f32>,
555 b: &SparseMatrix<f32>,
556 ) -> BackendResult<SparseMatrix<f32>> {
557 let mut result = SparseMatrix::new_coo(a.rows, b.cols);
560
561 for row_a in 0..a.rows {
562 let start_a = a.row_indices[row_a];
563 let end_a = a.row_indices[row_a + 1];
564
565 for idx_a in start_a..end_a {
566 let col_a = a.col_indices[idx_a];
567 let val_a = a.values[idx_a];
568
569 let start_b = b.row_indices[col_a];
571 let end_b = b.row_indices[col_a + 1];
572
573 for idx_b in start_b..end_b {
574 let col_b = b.col_indices[idx_b];
575 let val_b = b.values[idx_b];
576
577 let product = val_a * val_b;
578 result.insert_coo(row_a, col_b, product)?;
579 }
580 }
581 }
582
583 Ok(result)
584 }
585
586 fn sparse_add_coo(
588 &self,
589 a: &SparseMatrix<f32>,
590 b: &SparseMatrix<f32>,
591 ) -> BackendResult<SparseMatrix<f32>> {
592 let mut result = SparseMatrix::new_coo(a.rows, a.cols);
593
594 let mut entries: HashMap<(usize, usize), f32> = HashMap::new();
596
597 for i in 0..a.nnz {
599 let key = (a.row_indices[i], a.col_indices[i]);
600 *entries.entry(key).or_insert(0.0) += a.values[i];
601 }
602
603 for i in 0..b.nnz {
605 let key = (b.row_indices[i], b.col_indices[i]);
606 *entries.entry(key).or_insert(0.0) += b.values[i];
607 }
608
609 for ((row, col), value) in entries {
611 if value != 0.0 {
612 result.insert_coo(row, col, value)?;
613 }
614 }
615
616 Ok(result)
617 }
618}
619
620impl<T: Clone + Default + PartialEq> SparseMatrix<T> {
    /// Converts this COO matrix into BSR (block sparse row) form.
    ///
    /// The matrix is tiled into `block_size` = (block_rows, block_cols) tiles;
    /// every tile containing at least one stored entry becomes a dense
    /// `block_rows * block_cols` slab in `values`, with missing positions
    /// filled by `T::default()`. In the returned matrix `nnz` counts stored
    /// *blocks* (not scalar entries) and `row_indices` holds block-row
    /// pointers into `col_indices`.
    ///
    /// # Errors
    /// Fails when `self` is not COO or when either block dimension is zero.
    pub fn to_bsr(&self, block_size: (usize, usize)) -> BackendResult<SparseMatrix<T>> {
        if self.format != SparseFormat::Coo {
            return Err(TorshError::ComputeError(
                "Source matrix must be in COO format".to_string(),
            ));
        }

        let (block_rows, block_cols) = block_size;
        if block_rows == 0 || block_cols == 0 {
            return Err(TorshError::ComputeError(
                "Block size must be positive".to_string(),
            ));
        }

        // Ceiling division: partial edge blocks count as full blocks.
        let num_block_rows = (self.rows + block_rows - 1) / block_rows;
        let _num_block_cols = (self.cols + block_cols - 1) / block_cols;

        let mut bsr = SparseMatrix {
            format: SparseFormat::Bsr,
            rows: self.rows,
            cols: self.cols,
            nnz: 0,
            values: Vec::new(),
            row_indices: vec![0; num_block_rows + 1],
            col_indices: Vec::new(),
            block_size: Some(block_size),
        };

        // Gather entries into dense per-block buffers keyed by block coordinate.
        let mut blocks: HashMap<(usize, usize), Vec<T>> = HashMap::new();

        for i in 0..self.nnz {
            let row = self.row_indices[i];
            let col = self.col_indices[i];
            let val = self.values[i].clone();

            let block_row = row / block_rows;
            let block_col = col / block_cols;
            let in_block_row = row % block_rows;
            let in_block_col = col % block_cols;

            let block_entry = blocks
                .entry((block_row, block_col))
                .or_insert_with(|| vec![T::default(); block_rows * block_cols]);
            block_entry[in_block_row * block_cols + in_block_col] = val;
        }

        // Emit blocks in row-major block order so the pointer array is monotone.
        let mut sorted_blocks: Vec<_> = blocks.into_iter().collect();
        sorted_blocks.sort_by_key(|&((br, bc), _)| (br, bc));

        let mut current_block_row = 0;
        for ((block_row, block_col), block_values) in sorted_blocks {
            // Close out any empty block rows between the previous block and this one.
            while current_block_row < block_row {
                current_block_row += 1;
                bsr.row_indices[current_block_row] = bsr.col_indices.len();
            }

            bsr.col_indices.push(block_col);
            bsr.values.extend(block_values);
            bsr.nnz += 1; // nnz counts blocks here, not scalar entries
        }

        // Remaining (trailing empty) block rows all point at the end.
        let final_ptr = bsr.col_indices.len();
        for i in (current_block_row + 1)..=num_block_rows {
            bsr.row_indices[i] = final_ptr;
        }

        Ok(bsr)
    }
696
697 pub fn optimize(&mut self) -> BackendResult<()> {
699 match self.format {
700 SparseFormat::Coo => {
701 let mut triplets: Vec<(usize, usize, T)> = (0..self.nnz)
703 .filter_map(|i| {
704 let val = &self.values[i];
705 if *val != T::default() {
706 Some((self.row_indices[i], self.col_indices[i], val.clone()))
707 } else {
708 None
709 }
710 })
711 .collect();
712
713 triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
714
715 self.nnz = triplets.len();
717 self.row_indices.clear();
718 self.col_indices.clear();
719 self.values.clear();
720
721 for (row, col, val) in triplets {
722 self.row_indices.push(row);
723 self.col_indices.push(col);
724 self.values.push(val);
725 }
726 }
727 SparseFormat::Csr | SparseFormat::Csc => {
728 let mut new_values = Vec::new();
730 let mut new_col_indices = Vec::new();
731 let mut new_row_pointers = vec![0];
732
733 let num_rows = if self.format == SparseFormat::Csr {
734 self.rows
735 } else {
736 self.cols
737 };
738
739 for row in 0..num_rows {
740 let start = self.row_indices[row];
741 let end = self.row_indices[row + 1];
742
743 for idx in start..end {
744 if self.values[idx] != T::default() {
745 new_values.push(self.values[idx].clone());
746 new_col_indices.push(self.col_indices[idx]);
747 }
748 }
749 new_row_pointers.push(new_values.len());
750 }
751
752 self.values = new_values;
753 self.col_indices = new_col_indices;
754 self.row_indices = new_row_pointers;
755 self.nnz = self.values.len();
756 }
757 _ => {
758 return Err(TorshError::ComputeError(
759 "Optimization not supported for this format".to_string(),
760 ))
761 }
762 }
763
764 Ok(())
765 }
766
767 pub fn statistics(&self) -> SparseMatrixStatistics {
769 let mut max_row_nnz = 0;
770 let mut min_row_nnz = usize::MAX;
771 let mut row_nnz_variance = 0.0;
772
773 match self.format {
774 SparseFormat::Csr => {
775 let mut row_counts = Vec::new();
776 for row in 0..self.rows {
777 let count = self.row_indices[row + 1] - self.row_indices[row];
778 row_counts.push(count);
779 max_row_nnz = max_row_nnz.max(count);
780 min_row_nnz = min_row_nnz.min(count);
781 }
782
783 let mean = row_counts.iter().sum::<usize>() as f64 / row_counts.len() as f64;
784 row_nnz_variance = row_counts
785 .iter()
786 .map(|&x| (x as f64 - mean).powi(2))
787 .sum::<f64>()
788 / row_counts.len() as f64;
789 }
790 SparseFormat::Coo => {
791 let mut row_counts = vec![0; self.rows];
792 for &row in &self.row_indices {
793 row_counts[row] += 1;
794 }
795 max_row_nnz = *row_counts.iter().max().unwrap_or(&0);
796 min_row_nnz = *row_counts.iter().min().unwrap_or(&0);
797
798 let mean = self.nnz as f64 / self.rows as f64;
799 row_nnz_variance = row_counts
800 .iter()
801 .map(|&x| (x as f64 - mean).powi(2))
802 .sum::<f64>()
803 / self.rows as f64;
804 }
805 _ => {
806 min_row_nnz = if self.nnz == 0 { 0 } else { 1 };
808 }
809 }
810
811 SparseMatrixStatistics {
812 format: self.format,
813 rows: self.rows,
814 cols: self.cols,
815 nnz: self.nnz,
816 sparsity_ratio: self.sparsity_ratio(),
817 max_row_nnz,
818 min_row_nnz,
819 row_nnz_variance,
820 memory_usage: self.estimated_memory_usage(),
821 }
822 }
823
    /// Estimated in-memory footprint in bytes; delegates to
    /// `SparseFormatConverter::estimate_memory_usage`, which assumes
    /// f32-sized values regardless of `T`.
    fn estimated_memory_usage(&self) -> usize {
        SparseFormatConverter::estimate_memory_usage(self.rows, self.cols, self.nnz, self.format)
    }
828}
829
/// Tuning knobs callers can pass to sparse backends.
///
/// NOTE(review): none of the kernels in this module read these hints yet;
/// confirm downstream consumers before relying on them.
#[derive(Debug, Clone)]
pub struct SparseOptimizationHints {
    /// Prefer lower memory usage over raw speed.
    pub memory_efficient: bool,
    /// Allow multi-threaded kernels.
    pub use_parallel: bool,
    /// Anticipated fill ratio (stored entries / total entries).
    pub expected_sparsity: f64,
    /// Preferred BSR block dimensions, when known.
    pub block_size: Option<(usize, usize)>,
    /// Block size for cache-aware tiling (presumably in elements — confirm).
    pub cache_block_size: usize,
}
844
845impl Default for SparseOptimizationHints {
846 fn default() -> Self {
847 Self {
848 memory_efficient: true,
849 use_parallel: true,
850 expected_sparsity: 0.1, block_size: None,
852 cache_block_size: 64,
853 }
854 }
855}
856
/// Stateless helper for picking sparse formats and estimating their memory cost.
pub struct SparseFormatConverter;
859
860impl SparseFormatConverter {
861 pub fn choose_optimal_format<T>(
863 _matrix: &SparseMatrix<T>,
864 operation: SparseOperation,
865 ) -> SparseFormat {
866 match operation {
867 SparseOperation::SpMV => {
868 SparseFormat::Csr
870 }
871 SparseOperation::SpMM => {
872 SparseFormat::Csr
874 }
875 SparseOperation::Addition => {
876 SparseFormat::Coo
878 }
879 SparseOperation::Transpose => {
880 SparseFormat::Coo
882 }
883 SparseOperation::Iterative => {
884 SparseFormat::Csr
886 }
887 }
888 }
889
890 pub fn estimate_memory_usage(
892 rows: usize,
893 cols: usize,
894 nnz: usize,
895 format: SparseFormat,
896 ) -> usize {
897 match format {
898 SparseFormat::Coo => {
899 nnz * (std::mem::size_of::<usize>() * 2 + std::mem::size_of::<f32>())
901 }
902 SparseFormat::Csr => {
903 (rows + 1) * std::mem::size_of::<usize>()
905 + nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
906 }
907 SparseFormat::Csc => {
908 (cols + 1) * std::mem::size_of::<usize>()
910 + nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
911 }
912 SparseFormat::Dense => rows * cols * std::mem::size_of::<f32>(),
913 _ => nnz * std::mem::size_of::<f32>() * 3, }
915 }
916}
917
/// Operation categories used when choosing an optimal storage format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseOperation {
    /// Sparse matrix–vector product.
    SpMV,
    /// Sparse matrix–matrix product.
    SpMM,
    /// Element-wise addition.
    Addition,
    /// Matrix transposition.
    Transpose,
    /// Iterative solver workloads (repeated SpMV-style access).
    Iterative,
}
932
/// Structural summary of a sparse matrix, produced by `SparseMatrix::statistics`.
#[derive(Debug, Clone)]
pub struct SparseMatrixStatistics {
    /// Storage layout of the analyzed matrix.
    pub format: SparseFormat,
    /// Logical row count.
    pub rows: usize,
    /// Logical column count.
    pub cols: usize,
    /// Number of explicitly stored entries.
    pub nnz: usize,
    /// Stored-entry fraction (nnz / (rows * cols)).
    pub sparsity_ratio: f64,
    /// Largest per-row entry count (CSR/COO only; 0 otherwise).
    pub max_row_nnz: usize,
    /// Smallest per-row entry count (CSR/COO only; placeholder otherwise).
    pub min_row_nnz: usize,
    /// Variance of the per-row entry counts (CSR/COO only; 0.0 otherwise).
    pub row_nnz_variance: f64,
    /// Estimated memory footprint in bytes (assumes f32-sized values).
    pub memory_usage: usize,
}
955
956impl SparseMatrixStatistics {
957 pub fn is_well_balanced(&self) -> bool {
959 if self.rows == 0 || self.nnz == 0 {
960 return true;
961 }
962
963 let avg_nnz_per_row = self.nnz as f64 / self.rows as f64;
964 let balance_ratio = self.max_row_nnz as f64 / avg_nnz_per_row.max(1.0);
965
966 balance_ratio < 3.0
968 }
969
970 pub fn recommended_operations(&self) -> Vec<&'static str> {
972 let mut recommendations = Vec::new();
973
974 if self.sparsity_ratio < 0.1 {
975 recommendations.push("Very sparse - excellent for sparse algorithms");
976 } else if self.sparsity_ratio > 0.5 {
977 recommendations.push("Dense - consider dense algorithms");
978 }
979
980 if !self.is_well_balanced() {
981 recommendations.push("Unbalanced structure - consider load balancing");
982 }
983
984 match self.format {
985 SparseFormat::Coo => {
986 recommendations.push("COO format - good for construction and element access")
987 }
988 SparseFormat::Csr => {
989 recommendations.push("CSR format - optimal for SpMV and most algorithms")
990 }
991 SparseFormat::Csc => recommendations.push("CSC format - good for transpose operations"),
992 SparseFormat::Bsr => {
993 recommendations.push("BSR format - optimal for block-structured sparsity")
994 }
995 SparseFormat::Dense => recommendations.push("Dense format - use dense linear algebra"),
996 }
997
998 recommendations
999 }
1000}
1001
/// Extension of [`SparseOps`] for backends with hardware acceleration.
pub trait HardwareSparseOps<T>: SparseOps<T> {
    /// Reports what acceleration features this backend offers.
    fn acceleration_capabilities(&self) -> SparseAccelerationCapabilities;

    /// Performs several independent SpMV operations in one call; the three
    /// slices are indexed in lockstep.
    fn batch_spmv(
        &self,
        matrices: &[&SparseMatrix<T>],
        vectors: &[&[T]],
        results: &mut [&mut [T]],
    ) -> BackendResult<()>;

    /// Fused `result = alpha * matrix * x + beta * y`.
    fn fused_spmv_add(
        &self,
        matrix: &SparseMatrix<T>,
        x: &[T],
        y: &[T],
        result: &mut [T],
        alpha: T,
        beta: T,
    ) -> BackendResult<()>;

    /// Iteratively solves `matrix * x = b` starting from `x0`, stopping at
    /// `tolerance` or after `max_iterations`.
    fn iterative_solve(
        &self,
        matrix: &SparseMatrix<T>,
        b: &[T],
        x0: &[T],
        method: IterativeMethod,
        tolerance: f64,
        max_iterations: usize,
    ) -> BackendResult<SolverResult<T>>;
}
1037
/// Hardware features a sparse backend can exploit.
#[derive(Debug, Clone)]
pub struct SparseAccelerationCapabilities {
    /// SIMD vector width in lanes (1 means scalar only).
    pub simd_width: usize,
    /// Whether a GPU path is available.
    pub gpu_acceleration: bool,
    /// Whether dedicated sparse hardware is available.
    pub specialized_hardware: bool,
    /// Whether multi-threaded kernels may be used.
    pub parallel_execution: bool,
    /// Estimated memory bandwidth (presumably GB/s — see the probe in
    /// `AdvancedSparseOps::detect_acceleration_capabilities`).
    pub memory_bandwidth: f32,
}
1052
/// Iterative solver algorithms accepted by `HardwareSparseOps::iterative_solve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterativeMethod {
    /// Conjugate gradient (symmetric positive-definite systems).
    ConjugateGradient,
    /// Biconjugate gradient stabilized.
    BiCGStab,
    /// Generalized minimal residual.
    GMRES,
    /// Jacobi iteration.
    Jacobi,
    /// Gauss–Seidel iteration.
    GaussSeidel,
}
1067
/// Outcome of an iterative solve.
#[derive(Debug, Clone)]
pub struct SolverResult<T> {
    /// Final solution vector (best iterate even when not converged).
    pub solution: Vec<T>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Residual norm at termination.
    pub residual_norm: f64,
    /// Whether the tolerance was reached within the iteration budget.
    pub converged: bool,
    /// Optional diagnostic message when the solve failed or stalled.
    pub error_message: Option<String>,
}
1082
/// Capability-aware wrapper over [`DefaultSparseOps`] that picks parallel,
/// SIMD, or scalar kernels and caches runtime estimates.
#[derive(Debug)]
pub struct AdvancedSparseOps {
    // Sequential reference implementation used as the fallback path.
    base_ops: DefaultSparseOps,
    // Capabilities detected once at construction time.
    acceleration_caps: SparseAccelerationCapabilities,
    // Cached runtime estimates keyed by "<op>_<rows>_<cols>_<nnz>".
    performance_cache: HashMap<String, f64>,
}
1090
1091impl AdvancedSparseOps {
1092 pub fn new(device: Device) -> Self {
1094 let acceleration_caps = Self::detect_acceleration_capabilities(&device);
1095 let base_ops = DefaultSparseOps::new(device);
1096
1097 Self {
1098 base_ops,
1099 acceleration_caps,
1100 performance_cache: HashMap::new(),
1101 }
1102 }
1103
    /// Heuristic capability probe for `device`.
    ///
    /// NOTE(review): these numbers are assumptions, not measurements — the
    /// SIMD width is guessed from the target architecture and the memory
    /// bandwidth is a fixed CPU-vs-accelerator constant (presumably GB/s).
    /// Confirm against real device queries if they feed scheduling decisions.
    fn detect_acceleration_capabilities(device: &Device) -> SparseAccelerationCapabilities {
        SparseAccelerationCapabilities {
            simd_width: if cfg!(target_arch = "x86_64") { 8 } else { 4 },
            gpu_acceleration: device.device_type() != torsh_core::device::DeviceType::Cpu,
            specialized_hardware: false,
            parallel_execution: true,
            memory_bandwidth: if device.device_type() == torsh_core::device::DeviceType::Cpu {
                50.0
            } else {
                500.0
            },
        }
    }
1118
1119 pub fn optimized_spmv(
1121 &mut self,
1122 matrix: &SparseMatrix<f32>,
1123 x: &[f32],
1124 y: &mut [f32],
1125 ) -> BackendResult<()> {
1126 let _cache_key = format!(
1127 "spmv_{}_{}_{}_{}",
1128 matrix.format as u8, matrix.rows, matrix.cols, matrix.nnz
1129 );
1130
1131 if self.acceleration_caps.parallel_execution && matrix.nnz > 10000 {
1133 self.parallel_spmv(matrix, x, y)
1134 } else if self.acceleration_caps.simd_width > 1 {
1135 self.simd_spmv(matrix, x, y)
1136 } else {
1137 self.base_ops.spmv(matrix, x, y)
1138 }
1139 }
1140
1141 fn parallel_spmv(
1143 &self,
1144 matrix: &SparseMatrix<f32>,
1145 x: &[f32],
1146 y: &mut [f32],
1147 ) -> BackendResult<()> {
1148 match matrix.format {
1149 SparseFormat::Csr => {
1150 use scirs2_core::parallel_ops::*;
1152
1153 let row_chunks: Vec<_> = (0..matrix.rows).collect();
1155 let chunk_size = (matrix.rows + current_num_threads() - 1) / current_num_threads();
1156
1157 row_chunks.par_chunks(chunk_size).for_each(|chunk| {
1158 for &row in chunk {
1159 let start = matrix.row_indices[row];
1160 let end = matrix.row_indices[row + 1];
1161 let mut sum = 0.0;
1162
1163 for idx in start..end {
1164 let col = matrix.col_indices[idx];
1165 let val = matrix.values[idx];
1166 sum += val * x[col];
1167 }
1168
1169 unsafe {
1171 let y_ptr = y.as_ptr() as *mut f32;
1172 *y_ptr.add(row) = sum;
1173 }
1174 }
1175 });
1176
1177 Ok(())
1178 }
1179 _ => self.base_ops.spmv(matrix, x, y), }
1181 }
1182
    /// SIMD SpMV entry point.
    ///
    /// Currently a placeholder that delegates to the scalar reference
    /// implementation; explicit vectorized kernels have not been written yet.
    fn simd_spmv(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
        self.base_ops.spmv(matrix, x, y)
    }
1189
    /// Converts `matrix` into the format predicted to be fastest for
    /// `target_operation`, returning a clone when it is already optimal.
    ///
    /// NOTE(review): `to_csr`/`to_csc`/`to_bsr` all require a COO source, so
    /// this returns an error whenever a non-COO matrix needs converting to a
    /// different compressed format — confirm callers only pass COO or
    /// already-optimal matrices.
    pub fn adaptive_format_conversion(
        &self,
        matrix: &SparseMatrix<f32>,
        target_operation: SparseOperation,
    ) -> BackendResult<SparseMatrix<f32>> {
        let stats = matrix.statistics();

        let optimal_format = if stats.is_well_balanced() {
            match target_operation {
                SparseOperation::SpMV => SparseFormat::Csr,
                SparseOperation::SpMM => SparseFormat::Csr,
                SparseOperation::Addition => SparseFormat::Coo,
                SparseOperation::Transpose => SparseFormat::Coo,
                SparseOperation::Iterative => SparseFormat::Csr,
            }
        } else {
            match target_operation {
                // Heavily skewed rows: COO avoids a single long-row hot spot.
                SparseOperation::SpMV if stats.max_row_nnz > stats.nnz / 10 => SparseFormat::Coo,
                _ => SparseFormatConverter::choose_optimal_format(matrix, target_operation),
            }
        };

        if matrix.format == optimal_format {
            Ok(matrix.clone())
        } else {
            match optimal_format {
                SparseFormat::Csr => matrix.to_csr(),
                SparseFormat::Csc => matrix.to_csc(),
                SparseFormat::Bsr => {
                    let block_size = (8, 8); // fixed heuristic block size
                    matrix.to_bsr(block_size)
                }
                _ => Ok(matrix.clone()),
            }
        }
    }
1228
1229 pub fn benchmark_operation(&mut self, operation: &str, matrix: &SparseMatrix<f32>) -> f64 {
1231 let cache_key = format!(
1232 "{}_{}_{}_{}",
1233 operation, matrix.rows, matrix.cols, matrix.nnz
1234 );
1235
1236 if let Some(&cached_time) = self.performance_cache.get(&cache_key) {
1237 return cached_time;
1238 }
1239
1240 let estimated_time = match operation {
1242 "spmv" => {
1243 (matrix.nnz as f64 * 2.0) / (self.acceleration_caps.memory_bandwidth as f64 * 1e9)
1244 }
1245 "spmm" => {
1246 (matrix.nnz as f64 * matrix.cols as f64 * 2.0)
1247 / (self.acceleration_caps.memory_bandwidth as f64 * 1e9)
1248 }
1249 _ => 0.001, };
1251
1252 self.performance_cache.insert(cache_key, estimated_time);
1253 estimated_time
1254 }
1255}
1256
/// Free-standing linear-algebra utilities over [`SparseMatrix`].
pub struct SparseLinAlg;
1259
1260impl SparseLinAlg {
1261 pub fn frobenius_norm<T>(matrix: &SparseMatrix<T>) -> f64
1263 where
1264 T: Clone + Default + PartialEq,
1265 f64: From<T>,
1266 {
1267 let mut sum = 0.0;
1268 for value in &matrix.values {
1269 let val: f64 = value.clone().into();
1270 sum += val * val;
1271 }
1272 sum.sqrt()
1273 }
1274
1275 pub fn is_symmetric(matrix: &SparseMatrix<f32>, tolerance: f32) -> BackendResult<bool> {
1277 if matrix.rows != matrix.cols {
1278 return Ok(false);
1279 }
1280
1281 let coo = if matrix.format == SparseFormat::Coo {
1283 matrix.clone()
1284 } else {
1285 return Err(TorshError::ComputeError(
1286 "Symmetry check requires COO format".to_string(),
1287 ));
1288 };
1289
1290 let mut entries: HashMap<(usize, usize), f32> = HashMap::new();
1292 for i in 0..coo.nnz {
1293 entries.insert((coo.row_indices[i], coo.col_indices[i]), coo.values[i]);
1294 }
1295
1296 for ((row, col), &value) in &entries {
1298 if let Some(&transpose_value) = entries.get(&(*col, *row)) {
1299 if (value - transpose_value).abs() > tolerance {
1300 return Ok(false);
1301 }
1302 } else if value.abs() > tolerance {
1303 return Ok(false);
1305 }
1306 }
1307
1308 Ok(true)
1309 }
1310
    /// Extracts the main diagonal, padded with `T::default()` where no
    /// explicit entry exists; the result length is `min(rows, cols)`.
    ///
    /// Only COO and CSR layouts are scanned; any other format returns the
    /// all-default vector unchanged.
    pub fn diagonal<T: Clone + Default>(matrix: &SparseMatrix<T>) -> Vec<T> {
        let mut diag = vec![T::default(); matrix.rows.min(matrix.cols)];

        match matrix.format {
            SparseFormat::Coo => {
                for i in 0..matrix.nnz {
                    let row = matrix.row_indices[i];
                    let col = matrix.col_indices[i];
                    if row == col && row < diag.len() {
                        diag[row] = matrix.values[i].clone();
                    }
                }
            }
            SparseFormat::Csr => {
                for row in 0..matrix.rows.min(diag.len()) {
                    let start = matrix.row_indices[row];
                    let end = matrix.row_indices[row + 1];

                    for idx in start..end {
                        let col = matrix.col_indices[idx];
                        if col == row {
                            diag[row] = matrix.values[idx].clone();
                            break;
                        } else if col > row {
                            // Assumes column indices are sorted within each row
                            // (true for matrices produced by `to_csr`) — the
                            // diagonal entry cannot appear after this point.
                            break;
                        }
                    }
                }
            }
            _ => {
                // CSC/BSR/Dense: not implemented; diagonal stays all-default.
            }
        }

        diag
    }
1349}
1350
#[cfg(test)]
mod tests {
    use super::*;

    /// COO construction, insertion, and sparsity bookkeeping.
    #[test]
    fn test_sparse_matrix_creation() {
        let mut matrix = SparseMatrix::<f32>::new_coo(3, 3);
        assert_eq!(matrix.rows, 3);
        assert_eq!(matrix.cols, 3);
        assert_eq!(matrix.nnz, 0);
        assert_eq!(matrix.format, SparseFormat::Coo);

        matrix.insert_coo(0, 0, 1.0).unwrap();
        matrix.insert_coo(1, 1, 2.0).unwrap();
        matrix.insert_coo(2, 2, 3.0).unwrap();

        assert_eq!(matrix.nnz, 3);
        assert_eq!(matrix.sparsity_ratio(), 3.0 / 9.0);
        assert!(matrix.is_sparse());
    }

    /// COO → CSR keeps all entries and builds the expected row pointers.
    #[test]
    fn test_coo_to_csr_conversion() {
        let mut coo = SparseMatrix::<f32>::new_coo(3, 3);
        coo.insert_coo(0, 0, 1.0).unwrap();
        coo.insert_coo(0, 2, 2.0).unwrap();
        coo.insert_coo(1, 1, 3.0).unwrap();
        coo.insert_coo(2, 0, 4.0).unwrap();
        coo.insert_coo(2, 2, 5.0).unwrap();

        let csr = coo.to_csr().unwrap();
        assert_eq!(csr.format, SparseFormat::Csr);
        assert_eq!(csr.nnz, 5);

        // Rows hold 2, 1, 2 entries => pointers 0, 2, 3, 5.
        assert_eq!(csr.row_indices, vec![0, 2, 3, 5]);
    }

    /// SpMV on a diagonal CSR matrix scales x element-wise.
    #[test]
    fn test_sparse_spmv() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let mut matrix = SparseMatrix::<f32>::new_coo(3, 3);
        matrix.insert_coo(0, 0, 2.0).unwrap();
        matrix.insert_coo(1, 1, 3.0).unwrap();
        matrix.insert_coo(2, 2, 4.0).unwrap();

        let csr_matrix = matrix.to_csr().unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0; 3];

        sparse_ops.spmv(&csr_matrix, &x, &mut y).unwrap();

        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    /// COO → dense expansion places entries row-major, zeros elsewhere.
    #[test]
    fn test_sparse_to_dense() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let mut matrix = SparseMatrix::<f32>::new_coo(2, 2);
        matrix.insert_coo(0, 0, 1.0).unwrap();
        matrix.insert_coo(1, 1, 2.0).unwrap();

        let dense = sparse_ops.to_dense(&matrix).unwrap();
        assert_eq!(dense, vec![1.0, 0.0, 0.0, 2.0]);
    }

    /// Dense → COO keeps only entries above the threshold.
    #[test]
    fn test_sparse_from_dense() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let dense = vec![1.0, 0.0, 0.0, 2.0];
        let sparse = sparse_ops.from_dense(&dense, 2, 2, 0.1).unwrap();

        assert_eq!(sparse.nnz, 2);
        assert_eq!(sparse.sparsity_ratio(), 0.5);
    }

    /// COO transpose swaps dimensions and the two index vectors.
    #[test]
    fn test_sparse_transpose() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let mut matrix = SparseMatrix::<f32>::new_coo(2, 3);
        matrix.insert_coo(0, 1, 1.0).unwrap();
        matrix.insert_coo(1, 2, 2.0).unwrap();

        let transposed = sparse_ops.transpose(&matrix).unwrap();

        assert_eq!(transposed.rows, 3);
        assert_eq!(transposed.cols, 2);
        assert_eq!(transposed.nnz, 2);

        assert_eq!(transposed.row_indices, vec![1, 2]);
        assert_eq!(transposed.col_indices, vec![0, 1]);
    }

    /// Dense storage must dominate both sparse estimates at 1% fill;
    /// CSR's single pointer array beats COO's two index vectors.
    #[test]
    fn test_memory_usage_estimation() {
        let rows = 1000;
        let cols = 1000;
        let nnz = 10000;

        let coo_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Coo);
        let csr_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Csr);
        let dense_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Dense);

        assert!(dense_memory > coo_memory);
        assert!(dense_memory > csr_memory);

        assert!(csr_memory < coo_memory);
    }

    /// Static format recommendation: CSR for SpMV, COO for addition.
    #[test]
    fn test_format_selection() {
        let matrix = SparseMatrix::<f32>::new_coo(100, 100);

        let spmv_format =
            SparseFormatConverter::choose_optimal_format(&matrix, SparseOperation::SpMV);
        assert_eq!(spmv_format, SparseFormat::Csr);

        let add_format =
            SparseFormatConverter::choose_optimal_format(&matrix, SparseOperation::Addition);
        assert_eq!(add_format, SparseFormat::Coo);
    }
}