1use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use scirs2_core::GpuDataType;
9use std::cmp::PartialEq;
10
/// Sparse matrix in Compressed Sparse Row (CSR) format.
///
/// Standard CSR invariants: for row `i`, the stored entries occupy
/// positions `indptr[i]..indptr[i + 1]` of `indices` (column index) and
/// `data` (value); `indptr` has length `rows + 1`.
#[derive(Clone, Debug)]
pub struct CsrMatrix<T> {
    // Number of rows in the logical matrix.
    rows: usize,
    // Number of columns in the logical matrix.
    cols: usize,
    /// Row pointer array, length `rows + 1`; row `i` spans `indptr[i]..indptr[i + 1]`.
    pub indptr: Vec<usize>,
    /// Column index of each stored entry (parallel to `data`).
    pub indices: Vec<usize>,
    /// Stored values (parallel to `indices`).
    pub data: Vec<T>,
}
28
29impl<T> CsrMatrix<T>
30where
31 T: Clone + Copy + Zero + PartialEq + SparseElement,
32{
33 pub fn get(&self, row: usize, col: usize) -> T {
35 if row >= self.rows || col >= self.cols {
37 return T::sparse_zero();
38 }
39
40 for j in self.indptr[row]..self.indptr[row + 1] {
42 if self.indices[j] == col {
43 return self.data[j];
44 }
45 }
46
47 T::sparse_zero()
49 }
50
51 pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
53 let mut rows = Vec::new();
54 let mut cols = Vec::new();
55 let mut values = Vec::new();
56
57 for i in 0..self.rows {
58 for j in self.indptr[i]..self.indptr[i + 1] {
59 rows.push(i);
60 cols.push(self.indices[j]);
61 values.push(self.data[j]);
62 }
63 }
64
65 (rows, cols, values)
66 }
67 pub fn new(
94 data: Vec<T>,
95 rowindices: Vec<usize>,
96 colindices: Vec<usize>,
97 shape: (usize, usize),
98 ) -> SparseResult<Self> {
99 if data.len() != rowindices.len() || data.len() != colindices.len() {
101 return Err(SparseError::DimensionMismatch {
102 expected: data.len(),
103 found: std::cmp::min(rowindices.len(), colindices.len()),
104 });
105 }
106
107 let (rows, cols) = shape;
108
109 if rowindices.iter().any(|&i| i >= rows) {
111 return Err(SparseError::ValueError(
112 "Row index out of bounds".to_string(),
113 ));
114 }
115
116 if colindices.iter().any(|&i| i >= cols) {
117 return Err(SparseError::ValueError(
118 "Column index out of bounds".to_string(),
119 ));
120 }
121
122 let mut triplets: Vec<(usize, usize, T)> = rowindices
125 .into_iter()
126 .zip(colindices)
127 .zip(data)
128 .map(|((r, c), v)| (r, c, v))
129 .collect();
130 triplets.sort_by_key(|&(r, c_, _)| (r, c_));
131
132 let nnz = triplets.len();
134 let mut indptr = vec![0; rows + 1];
135 let mut indices = Vec::with_capacity(nnz);
136 let mut data_out = Vec::with_capacity(nnz);
137
138 for &(r_, _, _) in &triplets {
140 indptr[r_ + 1] += 1;
141 }
142
143 for i in 1..=rows {
145 indptr[i] += indptr[i - 1];
146 }
147
148 for (_r, c, v) in triplets {
150 indices.push(c);
151 data_out.push(v);
152 }
153
154 Ok(CsrMatrix {
155 rows,
156 cols,
157 indptr,
158 indices,
159 data: data_out,
160 })
161 }
162
163 pub fn from_triplets(
194 nrows: usize,
195 ncols: usize,
196 row_indices: Vec<usize>,
197 col_indices: Vec<usize>,
198 values: Vec<T>,
199 ) -> SparseResult<Self> {
200 Self::new(values, row_indices, col_indices, (nrows, ncols))
201 }
202
203 pub fn try_from_triplets(
238 nrows: usize,
239 ncols: usize,
240 triplets: &[(usize, usize, T)],
241 ) -> SparseResult<Self> {
242 let mut row_indices = Vec::with_capacity(triplets.len());
243 let mut col_indices = Vec::with_capacity(triplets.len());
244 let mut values = Vec::with_capacity(triplets.len());
245
246 for &(r, c, v) in triplets {
247 row_indices.push(r);
248 col_indices.push(c);
249 values.push(v);
250 }
251
252 Self::from_triplets(nrows, ncols, row_indices, col_indices, values)
253 }
254
255 pub fn from_raw_csr(
268 data: Vec<T>,
269 indptr: Vec<usize>,
270 indices: Vec<usize>,
271 shape: (usize, usize),
272 ) -> SparseResult<Self> {
273 let (rows, cols) = shape;
274
275 if indptr.len() != rows + 1 {
277 return Err(SparseError::DimensionMismatch {
278 expected: rows + 1,
279 found: indptr.len(),
280 });
281 }
282
283 if data.len() != indices.len() {
284 return Err(SparseError::DimensionMismatch {
285 expected: data.len(),
286 found: indices.len(),
287 });
288 }
289
290 for i in 1..indptr.len() {
292 if indptr[i] < indptr[i - 1] {
293 return Err(SparseError::ValueError(
294 "Row pointer array must be monotonically increasing".to_string(),
295 ));
296 }
297 }
298
299 if indptr[rows] != data.len() {
301 return Err(SparseError::ValueError(
302 "Last row pointer entry must match data length".to_string(),
303 ));
304 }
305
306 if indices.iter().any(|&i| i >= cols) {
308 return Err(SparseError::ValueError(
309 "Column index out of bounds".to_string(),
310 ));
311 }
312
313 Ok(CsrMatrix {
314 rows,
315 cols,
316 indptr,
317 indices,
318 data,
319 })
320 }
321
322 pub fn empty(shape: (usize, usize)) -> Self {
332 let (rows, cols) = shape;
333 let indptr = vec![0; rows + 1];
334
335 CsrMatrix {
336 rows,
337 cols,
338 indptr,
339 indices: Vec::new(),
340 data: Vec::new(),
341 }
342 }
343
344 pub fn rows(&self) -> usize {
346 self.rows
347 }
348
349 pub fn cols(&self) -> usize {
351 self.cols
352 }
353
354 pub fn shape(&self) -> (usize, usize) {
356 (self.rows, self.cols)
357 }
358
359 pub fn nnz(&self) -> usize {
361 self.data.len()
362 }
363
364 pub fn to_dense(&self) -> Vec<Vec<T>>
366 where
367 T: Zero + Copy + SparseElement,
368 {
369 let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
370
371 for (row_idx, row) in result.iter_mut().enumerate() {
372 for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
373 let col_idx = self.indices[j];
374 row[col_idx] = self.data[j];
375 }
376 }
377
378 result
379 }
380
381 pub fn transpose(&self) -> Self {
383 let mut col_counts = vec![0; self.cols];
385 for &col in &self.indices {
386 col_counts[col] += 1;
387 }
388
389 let mut col_ptrs = vec![0; self.cols + 1];
391 for i in 0..self.cols {
392 col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
393 }
394
395 let nnz = self.nnz();
397 let mut indices_t = vec![0; nnz];
398 let mut data_t = vec![T::sparse_zero(); nnz];
399 let mut col_counts = vec![0; self.cols];
400
401 for row in 0..self.rows {
402 for j in self.indptr[row]..self.indptr[row + 1] {
403 let col = self.indices[j];
404 let dest = col_ptrs[col] + col_counts[col];
405
406 indices_t[dest] = row;
407 data_t[dest] = self.data[j];
408 col_counts[col] += 1;
409 }
410 }
411
412 CsrMatrix {
413 rows: self.cols,
414 cols: self.rows,
415 indptr: col_ptrs,
416 indices: indices_t,
417 data: data_t,
418 }
419 }
420}
421
422impl<
423 T: Clone
424 + Copy
425 + std::ops::AddAssign
426 + std::ops::MulAssign
427 + std::cmp::PartialEq
428 + std::fmt::Debug
429 + scirs2_core::numeric::Zero
430 + std::ops::Add<Output = T>
431 + std::ops::Mul<Output = T>
432 + SparseElement,
433 > CsrMatrix<T>
434{
435 pub fn is_symmetric(&self) -> bool {
441 if self.rows != self.cols {
442 return false;
443 }
444
445 let transposed = self.transpose();
447
448 if self.nnz() != transposed.nnz() {
450 return false;
451 }
452
453 for row in 0..self.rows {
455 let self_start = self.indptr[row];
456 let self_end = self.indptr[row + 1];
457 let trans_start = transposed.indptr[row];
458 let trans_end = transposed.indptr[row + 1];
459
460 if self_end - self_start != trans_end - trans_start {
461 return false;
462 }
463
464 let mut self_entries: Vec<(usize, &T)> = (self_start..self_end)
466 .map(|j| (self.indices[j], &self.data[j]))
467 .collect();
468 self_entries.sort_by_key(|(col_, _)| *col_);
469
470 let mut trans_entries: Vec<(usize, &T)> = (trans_start..trans_end)
471 .map(|j| (transposed.indices[j], &transposed.data[j]))
472 .collect();
473 trans_entries.sort_by_key(|(col_, _)| *col_);
474
475 for i in 0..self_entries.len() {
477 if self_entries[i].0 != trans_entries[i].0
478 || self_entries[i].1 != trans_entries[i].1
479 {
480 return false;
481 }
482 }
483 }
484
485 true
486 }
487
488 pub fn matmul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>> {
498 if self.cols != other.rows {
499 return Err(SparseError::DimensionMismatch {
500 expected: self.cols,
501 found: other.rows,
502 });
503 }
504
505 let a_dense = self.to_dense();
508 let b_dense = other.to_dense();
509
510 let m = self.rows;
511 let n = other.cols;
512 let k = self.cols;
513
514 let mut c_dense = vec![vec![T::sparse_zero(); n]; m];
515
516 for (i, c_row) in c_dense.iter_mut().enumerate().take(m) {
517 for (j, val) in c_row.iter_mut().enumerate().take(n) {
518 for (l, &a_val) in a_dense[i].iter().enumerate().take(k) {
519 let prod = a_val * b_dense[l][j];
520 *val += prod;
521 }
522 }
523 }
524
525 let mut rowindices = Vec::new();
527 let mut colindices = Vec::new();
528 let mut values = Vec::new();
529
530 for (i, row) in c_dense.iter().enumerate() {
531 for (j, val) in row.iter().enumerate() {
532 if *val != T::sparse_zero() {
533 rowindices.push(i);
534 colindices.push(j);
535 values.push(*val);
536 }
537 }
538 }
539
540 CsrMatrix::new(values, rowindices, colindices, (m, n))
541 }
542
543 pub fn row_range(&self, row: usize) -> std::ops::Range<usize> {
553 assert!(row < self.rows, "Row index out of bounds");
554 self.indptr[row]..self.indptr[row + 1]
555 }
556
557 pub fn colindices(&self) -> &[usize] {
559 &self.indices
560 }
561
562 pub fn submatrix(
601 &self,
602 row_start: usize,
603 row_end: usize,
604 col_start: usize,
605 col_end: usize,
606 ) -> SparseResult<CsrMatrix<T>> {
607 let row_end = row_end.min(self.rows);
608 let col_end = col_end.min(self.cols);
609 if row_start >= row_end {
610 return Err(SparseError::ValueError(format!(
611 "submatrix: row_start ({}) >= row_end ({})",
612 row_start, row_end
613 )));
614 }
615 if col_start >= col_end {
616 return Err(SparseError::ValueError(format!(
617 "submatrix: col_start ({}) >= col_end ({})",
618 col_start, col_end
619 )));
620 }
621
622 let new_rows = row_end - row_start;
623 let new_cols = col_end - col_start;
624 let mut rows_out = Vec::new();
625 let mut cols_out = Vec::new();
626 let mut data_out = Vec::new();
627
628 for i in row_start..row_end {
629 let range = self.indptr[i]..self.indptr[i + 1];
630 for pos in range {
631 let j = self.indices[pos];
632 if j >= col_start && j < col_end {
633 rows_out.push(i - row_start);
634 cols_out.push(j - col_start);
635 data_out.push(self.data[pos]);
636 }
637 }
638 }
639
640 CsrMatrix::new(data_out, rows_out, cols_out, (new_rows, new_cols))
641 }
642
643 pub fn elementwise_mul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>>
681 where
682 T: std::ops::Mul<Output = T>,
683 {
684 if self.rows != other.rows || self.cols != other.cols {
685 return Err(SparseError::DimensionMismatch {
686 expected: self.rows * self.cols,
687 found: other.rows * other.cols,
688 });
689 }
690
691 let n = self.rows;
692 let nc = self.cols;
693 let mut rows_out = Vec::new();
694 let mut cols_out = Vec::new();
695 let mut data_out = Vec::new();
696
697 let mut b_row_buf: Vec<(usize, T)> = Vec::new();
703
704 for i in 0..n {
705 b_row_buf.clear();
707 let b_range = other.indptr[i]..other.indptr[i + 1];
708 for pos in b_range {
709 b_row_buf.push((other.indices[pos], other.data[pos]));
710 }
711
712 let a_range = self.indptr[i]..self.indptr[i + 1];
714 for pos in a_range {
715 let j = self.indices[pos];
716 let a_val = self.data[pos];
717 if let Some(&(_, b_val)) = b_row_buf.iter().find(|&&(bj, _)| bj == j) {
719 let product = a_val * b_val;
720 if product != T::sparse_zero() {
721 rows_out.push(i);
722 cols_out.push(j);
723 data_out.push(product);
724 }
725 }
726 }
727 }
728
729 CsrMatrix::new(data_out, rows_out, cols_out, (n, nc))
730 }
731}
732
733impl CsrMatrix<f64> {
734 pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
744 if vec.len() != self.cols {
745 return Err(SparseError::DimensionMismatch {
746 expected: self.cols,
747 found: vec.len(),
748 });
749 }
750
751 let mut result = vec![0.0; self.rows];
752
753 for (row_idx, result_val) in result.iter_mut().enumerate() {
754 for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
755 let col_idx = self.indices[j];
756 *result_val += self.data[j] * vec[col_idx];
757 }
758 }
759
760 Ok(result)
761 }
762
763 #[allow(dead_code)]
791 pub fn gpu_dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
792 let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::new()?;
794 gpu_spmv.spmv(
795 self.rows,
796 self.cols,
797 &self.indptr,
798 &self.indices,
799 &self.data,
800 vec,
801 )
802 }
803
804 #[allow(dead_code)]
815 pub fn gpu_dot_with_backend(
816 &self,
817 vec: &[f64],
818 backend: scirs2_core::gpu::GpuBackend,
819 ) -> SparseResult<Vec<f64>> {
820 let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::with_backend(backend)?;
822 gpu_spmv.spmv(
823 self.rows,
824 self.cols,
825 &self.indptr,
826 &self.indices,
827 &self.data,
828 vec,
829 )
830 }
831}
832
833impl<T> CsrMatrix<T>
834where
835 T: scirs2_core::numeric::Float
836 + std::fmt::Debug
837 + Copy
838 + Default
839 + GpuDataType
840 + Send
841 + Sync
842 + SparseElement
843 + std::ops::AddAssign
844 + std::ops::Mul<Output = T>
845 + 'static,
846{
847 #[allow(dead_code)]
857 pub fn gpu_dot_generic(&self, vec: &[T]) -> SparseResult<Vec<T>>
858where {
859 if vec.len() != self.cols {
861 return Err(SparseError::DimensionMismatch {
862 expected: self.cols,
863 found: vec.len(),
864 });
865 }
866
867 let mut result = vec![T::sparse_zero(); self.rows];
868
869 for (row_idx, result_val) in result.iter_mut().enumerate() {
870 let start = self.indptr[row_idx];
871 let end = self.indptr[row_idx + 1];
872
873 for idx in start..end {
874 let col = self.indices[idx];
875 *result_val += self.data[idx] * vec[col];
876 }
877 }
878
879 Ok(result)
880 }
881
882 pub fn should_use_gpu(&self) -> bool {
888 let nnz_threshold = 10000;
891 let density = self.nnz() as f64 / (self.rows * self.cols) as f64;
892
893 self.nnz() > nnz_threshold && density < 0.5
894 }
895
896 #[allow(dead_code)]
902 pub fn gpu_backend_info() -> SparseResult<(crate::gpu_ops::GpuBackend, String)> {
903 Ok((crate::gpu_ops::GpuBackend::Cpu, "CPU Fallback".to_string()))
905 }
906}
907
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Construction from COO triplets records shape and entry count.
    #[test]
    fn test_csr_create() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    // Dense expansion places each stored value at its (row, col) and zeros elsewhere.
    #[test]
    fn test_csr_to_dense() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
        let dense = matrix.to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    // Matrix-vector product against a hand-computed expected result.
    #[test]
    fn test_csr_dot() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");

        let vec = vec![1.0, 2.0, 3.0];
        let result = matrix.dot(&vec).expect("Operation failed");

        // Row 0: 1*1 + 2*3 = 7; row 1: 3*3 = 9; row 2: 4*1 + 5*2 = 14.
        let expected = [7.0, 9.0, 14.0];

        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    // Transpose swaps the off-diagonal entries; shape and nnz are preserved.
    #[test]
    fn test_csr_transpose() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
        let transposed = matrix.transpose();

        assert_eq!(transposed.shape(), (3, 3));
        assert_eq!(transposed.nnz(), 5);

        let dense = transposed.to_dense();
        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    // GPU SpMV: must match the CPU result when the GPU path is available,
    // and may legitimately fail with ComputationError/OperationNotSupported
    // where no GPU backend exists.
    #[test]
    fn test_gpu_dot() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
        let vec = vec![1.0, 2.0, 3.0];

        match matrix.gpu_dot(&vec) {
            Ok(result) => {
                let expected = [7.0, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-10);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {
            }
            Err(e) => panic!("Unexpected error in GPU SpMV: {:?}", e),
        }
    }

    // Heuristic thresholds: tiny matrices stay on CPU, a 15k-nnz diagonal
    // matrix crosses both the size and sparsity thresholds.
    #[test]
    fn test_should_use_gpu() {
        let small_matrix = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2))
            .expect("Operation failed");
        assert!(
            !small_matrix.should_use_gpu(),
            "Small matrix should not use GPU"
        );

        let large_data = vec![1.0; 15000];
        let large_rows: Vec<usize> = (0..15000).collect();
        let large_cols: Vec<usize> = (0..15000).collect();
        let large_matrix = CsrMatrix::new(large_data, large_rows, large_cols, (15000, 15000))
            .expect("Operation failed");
        assert!(
            large_matrix.should_use_gpu(),
            "Large sparse matrix should use GPU"
        );
    }

    // Backend query succeeds and reports a named, known backend variant.
    #[test]
    fn test_gpu_backend_info() {
        let backend_info = CsrMatrix::<f64>::gpu_backend_info();
        assert!(
            backend_info.is_ok(),
            "Should be able to get GPU backend info"
        );

        if let Ok((backend, name)) = backend_info {
            assert!(!name.is_empty(), "Backend name should not be empty");
            // Exhaustive match: a new backend variant forces this test to
            // be revisited at compile time.
            match backend {
                crate::gpu_ops::GpuBackend::Cuda
                | crate::gpu_ops::GpuBackend::OpenCL
                | crate::gpu_ops::GpuBackend::Metal
                | crate::gpu_ops::GpuBackend::Cpu
                | crate::gpu_ops::GpuBackend::Rocm
                | crate::gpu_ops::GpuBackend::Wgpu => {}
                #[cfg(not(feature = "gpu"))]
                crate::gpu_ops::GpuBackend::Vulkan => {}
            }
        }
    }

    // Generic SpMV path exercised with f32; same fallback-error tolerance
    // as test_gpu_dot.
    #[test]
    fn test_gpu_dot_generic_f32() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
        let vec = vec![1.0f32, 2.0, 3.0];

        match matrix.gpu_dot_generic(&vec) {
            Ok(result) => {
                let expected = [7.0f32, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-6);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {}
            Err(e) => panic!("Unexpected error in generic GPU SpMV: {:?}", e),
        }
    }
}