1use crate::Scalar;
7use crate::error::{CoreError, Result};
8use crate::tensor::Tensor;
9
10#[cfg_attr(
29 feature = "serde-support",
30 derive(serde::Serialize, serde::Deserialize)
31)]
32#[derive(Debug, Clone)]
33pub struct CooMatrix<T: Scalar> {
34 rows: Vec<usize>,
35 cols: Vec<usize>,
36 values: Vec<T>,
37 nrows: usize,
38 ncols: usize,
39}
40
41impl<T: Scalar> CooMatrix<T> {
42 pub fn new(nrows: usize, ncols: usize) -> Self {
53 Self {
54 rows: Vec::new(),
55 cols: Vec::new(),
56 values: Vec::new(),
57 nrows,
58 ncols,
59 }
60 }
61
62 pub fn from_triplets(
75 nrows: usize,
76 ncols: usize,
77 rows: Vec<usize>,
78 cols: Vec<usize>,
79 values: Vec<T>,
80 ) -> Result<Self> {
81 if rows.len() != cols.len() || rows.len() != values.len() {
82 return Err(CoreError::InvalidArgument {
83 reason: "rows, cols, and values must have the same length",
84 });
85 }
86 for (&r, &c) in rows.iter().zip(cols.iter()) {
87 if r >= nrows || c >= ncols {
88 return Err(CoreError::InvalidArgument {
89 reason: "index out of bounds for matrix dimensions",
90 });
91 }
92 }
93 Ok(Self {
94 rows,
95 cols,
96 values,
97 nrows,
98 ncols,
99 })
100 }
101
102 pub fn push(&mut self, row: usize, col: usize, value: T) -> Result<()> {
113 if row >= self.nrows || col >= self.ncols {
114 return Err(CoreError::InvalidArgument {
115 reason: "index out of bounds for matrix dimensions",
116 });
117 }
118 self.rows.push(row);
119 self.cols.push(col);
120 self.values.push(value);
121 Ok(())
122 }
123
124 #[inline]
126 pub fn nrows(&self) -> usize {
127 self.nrows
128 }
129
130 #[inline]
132 pub fn ncols(&self) -> usize {
133 self.ncols
134 }
135
136 #[inline]
138 pub fn nnz(&self) -> usize {
139 self.values.len()
140 }
141
142 #[inline]
144 pub fn shape(&self) -> (usize, usize) {
145 (self.nrows, self.ncols)
146 }
147
148 pub fn to_dense(&self) -> Tensor<T> {
163 let mut data = vec![T::zero(); self.nrows * self.ncols];
164 for ((&r, &c), &v) in self
165 .rows
166 .iter()
167 .zip(self.cols.iter())
168 .zip(self.values.iter())
169 {
170 data[r * self.ncols + c] += v;
171 }
172 Tensor::from_vec(data, vec![self.nrows, self.ncols])
174 .expect("dense data length equals nrows*ncols by construction")
175 }
176
177 pub fn to_csr(&self) -> CsrMatrix<T> {
191 let mut row_counts = vec![0usize; self.nrows + 1];
193 for &r in &self.rows {
194 row_counts[r + 1] += 1;
195 }
196 for i in 1..=self.nrows {
198 row_counts[i] += row_counts[i - 1];
199 }
200
201 let nnz = self.values.len();
202 let mut col_idx = vec![0usize; nnz];
203 let mut values = vec![T::zero(); nnz];
204 let mut offset = row_counts.clone();
205
206 for ((&r, &c), &v) in self
207 .rows
208 .iter()
209 .zip(self.cols.iter())
210 .zip(self.values.iter())
211 {
212 let pos = offset[r];
213 col_idx[pos] = c;
214 values[pos] = v;
215 offset[r] += 1;
216 }
217
218 let mut result = CsrMatrix {
220 row_ptr: row_counts,
221 col_idx,
222 values,
223 nrows: self.nrows,
224 ncols: self.ncols,
225 };
226 result.sort_and_sum_duplicates();
227 result
228 }
229
230 pub fn to_csc(&self) -> CscMatrix<T> {
244 let mut col_counts = vec![0usize; self.ncols + 1];
246 for &c in &self.cols {
247 col_counts[c + 1] += 1;
248 }
249 for i in 1..=self.ncols {
250 col_counts[i] += col_counts[i - 1];
251 }
252
253 let nnz = self.values.len();
254 let mut row_idx = vec![0usize; nnz];
255 let mut values = vec![T::zero(); nnz];
256 let mut offset = col_counts.clone();
257
258 for ((&r, &c), &v) in self
259 .rows
260 .iter()
261 .zip(self.cols.iter())
262 .zip(self.values.iter())
263 {
264 let pos = offset[c];
265 row_idx[pos] = r;
266 values[pos] = v;
267 offset[c] += 1;
268 }
269
270 let mut result = CscMatrix {
271 col_ptr: col_counts,
272 row_idx,
273 values,
274 nrows: self.nrows,
275 ncols: self.ncols,
276 };
277 result.sort_and_sum_duplicates();
278 result
279 }
280}
281
282#[cfg_attr(
300 feature = "serde-support",
301 derive(serde::Serialize, serde::Deserialize)
302)]
303#[derive(Debug, Clone)]
304pub struct CsrMatrix<T: Scalar> {
305 row_ptr: Vec<usize>,
306 col_idx: Vec<usize>,
307 values: Vec<T>,
308 nrows: usize,
309 ncols: usize,
310}
311
312impl<T: Scalar> CsrMatrix<T> {
313 pub fn new(nrows: usize, ncols: usize) -> Self {
323 Self {
324 row_ptr: vec![0; nrows + 1],
325 col_idx: Vec::new(),
326 values: Vec::new(),
327 nrows,
328 ncols,
329 }
330 }
331
332 pub fn from_triplets(
345 nrows: usize,
346 ncols: usize,
347 rows: Vec<usize>,
348 cols: Vec<usize>,
349 values: Vec<T>,
350 ) -> Result<Self> {
351 let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
352 Ok(coo.to_csr())
353 }
354
355 pub fn from_dense(tensor: &Tensor<T>) -> Result<Self> {
367 if tensor.ndim() != 2 {
368 return Err(CoreError::InvalidArgument {
369 reason: "from_dense requires a 2-D tensor",
370 });
371 }
372 let nrows = tensor.shape()[0];
373 let ncols = tensor.shape()[1];
374 let data = tensor.as_slice();
375
376 let mut row_ptr = vec![0usize; nrows + 1];
377 let mut col_idx = Vec::new();
378 let mut values = Vec::new();
379
380 for r in 0..nrows {
381 for c in 0..ncols {
382 let v = data[r * ncols + c];
383 if v != T::zero() {
384 col_idx.push(c);
385 values.push(v);
386 }
387 }
388 row_ptr[r + 1] = values.len();
389 }
390
391 Ok(Self {
392 row_ptr,
393 col_idx,
394 values,
395 nrows,
396 ncols,
397 })
398 }
399
400 #[inline]
402 pub fn nrows(&self) -> usize {
403 self.nrows
404 }
405
406 #[inline]
408 pub fn ncols(&self) -> usize {
409 self.ncols
410 }
411
412 #[inline]
414 pub fn nnz(&self) -> usize {
415 self.values.len()
416 }
417
418 #[inline]
420 pub fn shape(&self) -> (usize, usize) {
421 (self.nrows, self.ncols)
422 }
423
424 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
437 if row >= self.nrows || col >= self.ncols {
438 return None;
439 }
440 let start = self.row_ptr[row];
441 let end = self.row_ptr[row + 1];
442 self.col_idx[start..end]
443 .binary_search(&col)
444 .ok()
445 .map(|pos| &self.values[start + pos])
446 }
447
448 pub fn to_dense(&self) -> Tensor<T> {
461 let mut data = vec![T::zero(); self.nrows * self.ncols];
462 for r in 0..self.nrows {
463 let start = self.row_ptr[r];
464 let end = self.row_ptr[r + 1];
465 for idx in start..end {
466 let c = self.col_idx[idx];
467 data[r * self.ncols + c] = self.values[idx];
468 }
469 }
470 Tensor::from_vec(data, vec![self.nrows, self.ncols])
472 .expect("dense data length equals nrows*ncols by construction")
473 }
474
475 pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
492 if x.ndim() != 1 || x.numel() != self.ncols {
493 return Err(CoreError::DimensionMismatch {
494 expected: vec![self.ncols],
495 got: x.shape().to_vec(),
496 });
497 }
498 let xdata = x.as_slice();
499 let mut result = vec![T::zero(); self.nrows];
500
501 for (r, dest) in result.iter_mut().enumerate() {
502 let start = self.row_ptr[r];
503 let end = self.row_ptr[r + 1];
504 let mut acc = T::zero();
505 for idx in start..end {
506 acc += self.values[idx] * xdata[self.col_idx[idx]];
507 }
508 *dest = acc;
509 }
510
511 Tensor::from_vec(result, vec![self.nrows])
512 }
513
514 pub fn transpose(&self) -> CscMatrix<T> {
528 CscMatrix {
529 col_ptr: self.row_ptr.clone(),
530 row_idx: self.col_idx.clone(),
531 values: self.values.clone(),
532 nrows: self.ncols,
533 ncols: self.nrows,
534 }
535 }
536
537 pub fn to_coo(&self) -> CooMatrix<T> {
550 let mut rows = Vec::with_capacity(self.nnz());
551 let mut cols = Vec::with_capacity(self.nnz());
552 let mut values = Vec::with_capacity(self.nnz());
553
554 for r in 0..self.nrows {
555 let start = self.row_ptr[r];
556 let end = self.row_ptr[r + 1];
557 for idx in start..end {
558 rows.push(r);
559 cols.push(self.col_idx[idx]);
560 values.push(self.values[idx]);
561 }
562 }
563
564 CooMatrix {
565 rows,
566 cols,
567 values,
568 nrows: self.nrows,
569 ncols: self.ncols,
570 }
571 }
572
573 pub fn to_csc(&self) -> CscMatrix<T> {
586 self.to_coo().to_csc()
587 }
588
589 fn sort_and_sum_duplicates(&mut self) {
591 for r in 0..self.nrows {
592 let start = self.row_ptr[r];
593 let end = self.row_ptr[r + 1];
594 if start == end {
595 continue;
596 }
597
598 let len = end - start;
600 let mut perm: Vec<usize> = (0..len).collect();
601 perm.sort_unstable_by_key(|&i| self.col_idx[start + i]);
602
603 let old_cols: Vec<usize> = self.col_idx[start..end].to_vec();
604 let old_vals: Vec<T> = self.values[start..end].to_vec();
605 for (j, &p) in perm.iter().enumerate() {
606 self.col_idx[start + j] = old_cols[p];
607 self.values[start + j] = old_vals[p];
608 }
609
610 let mut write = start;
612 for read in (start + 1)..end {
613 if self.col_idx[read] == self.col_idx[write] {
614 let v = self.values[read];
615 self.values[write] += v;
616 } else {
617 write += 1;
618 self.col_idx[write] = self.col_idx[read];
619 self.values[write] = self.values[read];
620 }
621 }
622 let new_end = write + 1;
623
624 if new_end < end {
626 let removed = end - new_end;
627 let total_nnz = self.col_idx.len();
628 self.col_idx.copy_within(end..total_nnz, new_end);
629 self.col_idx.truncate(total_nnz - removed);
630 let total_vals = self.values.len();
631 self.values.copy_within(end..total_vals, new_end);
632 self.values.truncate(total_vals - removed);
633
634 for i in (r + 1)..=self.nrows {
635 self.row_ptr[i] -= removed;
636 }
637 }
638 }
639 }
640}
641
642#[cfg_attr(
659 feature = "serde-support",
660 derive(serde::Serialize, serde::Deserialize)
661)]
662#[derive(Debug, Clone)]
663pub struct CscMatrix<T: Scalar> {
664 col_ptr: Vec<usize>,
665 row_idx: Vec<usize>,
666 values: Vec<T>,
667 nrows: usize,
668 ncols: usize,
669}
670
671impl<T: Scalar> CscMatrix<T> {
672 pub fn new(nrows: usize, ncols: usize) -> Self {
682 Self {
683 col_ptr: vec![0; ncols + 1],
684 row_idx: Vec::new(),
685 values: Vec::new(),
686 nrows,
687 ncols,
688 }
689 }
690
691 pub fn from_triplets(
704 nrows: usize,
705 ncols: usize,
706 rows: Vec<usize>,
707 cols: Vec<usize>,
708 values: Vec<T>,
709 ) -> Result<Self> {
710 let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
711 Ok(coo.to_csc())
712 }
713
714 #[inline]
716 pub fn nrows(&self) -> usize {
717 self.nrows
718 }
719
720 #[inline]
722 pub fn ncols(&self) -> usize {
723 self.ncols
724 }
725
726 #[inline]
728 pub fn nnz(&self) -> usize {
729 self.values.len()
730 }
731
732 #[inline]
734 pub fn shape(&self) -> (usize, usize) {
735 (self.nrows, self.ncols)
736 }
737
738 pub fn to_dense(&self) -> Tensor<T> {
751 let mut data = vec![T::zero(); self.nrows * self.ncols];
752 for c in 0..self.ncols {
753 let start = self.col_ptr[c];
754 let end = self.col_ptr[c + 1];
755 for idx in start..end {
756 let r = self.row_idx[idx];
757 data[r * self.ncols + c] = self.values[idx];
758 }
759 }
760 Tensor::from_vec(data, vec![self.nrows, self.ncols])
762 .expect("dense data length equals nrows*ncols by construction")
763 }
764
765 pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
769 if x.ndim() != 1 || x.numel() != self.ncols {
770 return Err(CoreError::DimensionMismatch {
771 expected: vec![self.ncols],
772 got: x.shape().to_vec(),
773 });
774 }
775 let xdata = x.as_slice();
776 let mut result = vec![T::zero(); self.nrows];
777
778 for (c, &xc) in xdata.iter().enumerate().take(self.ncols) {
779 let start = self.col_ptr[c];
780 let end = self.col_ptr[c + 1];
781 for idx in start..end {
782 result[self.row_idx[idx]] += self.values[idx] * xc;
783 }
784 }
785
786 Tensor::from_vec(result, vec![self.nrows])
787 }
788
789 pub fn transpose(&self) -> CsrMatrix<T> {
791 CsrMatrix {
792 row_ptr: self.col_ptr.clone(),
793 col_idx: self.row_idx.clone(),
794 values: self.values.clone(),
795 nrows: self.ncols,
796 ncols: self.nrows,
797 }
798 }
799
800 pub fn to_coo(&self) -> CooMatrix<T> {
802 let mut rows = Vec::with_capacity(self.nnz());
803 let mut cols = Vec::with_capacity(self.nnz());
804 let mut values = Vec::with_capacity(self.nnz());
805
806 for c in 0..self.ncols {
807 let start = self.col_ptr[c];
808 let end = self.col_ptr[c + 1];
809 for idx in start..end {
810 rows.push(self.row_idx[idx]);
811 cols.push(c);
812 values.push(self.values[idx]);
813 }
814 }
815
816 CooMatrix {
817 rows,
818 cols,
819 values,
820 nrows: self.nrows,
821 ncols: self.ncols,
822 }
823 }
824
825 pub fn to_csr(&self) -> CsrMatrix<T> {
827 self.to_coo().to_csr()
828 }
829
830 fn sort_and_sum_duplicates(&mut self) {
832 for c in 0..self.ncols {
833 let start = self.col_ptr[c];
834 let end = self.col_ptr[c + 1];
835 if start == end {
836 continue;
837 }
838
839 let len = end - start;
840 let mut perm: Vec<usize> = (0..len).collect();
841 perm.sort_unstable_by_key(|&i| self.row_idx[start + i]);
842
843 let old_rows: Vec<usize> = self.row_idx[start..end].to_vec();
844 let old_vals: Vec<T> = self.values[start..end].to_vec();
845 for (j, &p) in perm.iter().enumerate() {
846 self.row_idx[start + j] = old_rows[p];
847 self.values[start + j] = old_vals[p];
848 }
849
850 let mut write = start;
852 for read in (start + 1)..end {
853 if self.row_idx[read] == self.row_idx[write] {
854 let v = self.values[read];
855 self.values[write] += v;
856 } else {
857 write += 1;
858 self.row_idx[write] = self.row_idx[read];
859 self.values[write] = self.values[read];
860 }
861 }
862 let new_end = write + 1;
863
864 if new_end < end {
865 let removed = end - new_end;
866 let total_idx = self.row_idx.len();
867 self.row_idx.copy_within(end..total_idx, new_end);
868 self.row_idx.truncate(total_idx - removed);
869 let total_vals = self.values.len();
870 self.values.copy_within(end..total_vals, new_end);
871 self.values.truncate(total_vals - removed);
872
873 for i in (c + 1)..=self.ncols {
874 self.col_ptr[i] -= removed;
875 }
876 }
877 }
878 }
879}
880
#[cfg(test)]
// Every expected value below is a small integer (or a sum of small integers),
// which f64 represents exactly, so exact float comparison is intentional.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // 3x3 fixture:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn sample_coo() -> CooMatrix<f64> {
        CooMatrix::from_triplets(
            3,
            3,
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    #[test]
    fn test_coo_to_dense() {
        let coo = sample_coo();
        let dense = coo.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(
            dense.as_slice(),
            &[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0]
        );
    }

    #[test]
    fn test_coo_from_triplets_rejects_bad_input() {
        // Mismatched vector lengths.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![0, 1], vec![0], vec![1.0]).is_err());
        // Row index out of bounds.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![2], vec![0], vec![1.0]).is_err());
        // Column index out of bounds.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![0], vec![2], vec![1.0]).is_err());
    }

    #[test]
    fn test_csr_from_dense_roundtrip() {
        let dense = Tensor::from_vec(
            vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let csr = CsrMatrix::from_dense(&dense).unwrap();
        assert_eq!(csr.nnz(), 5);
        let back = csr.to_dense();
        assert_eq!(dense, back);
    }

    #[test]
    fn test_csr_matvec() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csr.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_csc_matvec() {
        let csc = sample_coo().to_csc();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csc.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_coo_csr_csc_dense_roundtrip() {
        let coo = sample_coo();
        let expected = coo.to_dense();

        let csr = coo.to_csr();
        assert_eq!(csr.to_dense(), expected);

        let csc = csr.to_csc();
        assert_eq!(csc.to_dense(), expected);

        let coo2 = csc.to_coo();
        assert_eq!(coo2.to_dense(), expected);
    }

    #[test]
    fn test_identity_matrix() {
        let csr = CsrMatrix::from_dense(&Tensor::<f64>::eye(4)).unwrap();
        assert_eq!(csr.nnz(), 4);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = csr.matvec(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let csr = CsrMatrix::<f64>::new(3, 3);
        assert_eq!(csr.nnz(), 0);
        let dense = csr.to_dense();
        assert_eq!(dense, Tensor::<f64>::zeros(vec![3, 3]));
    }

    #[test]
    fn test_dimension_mismatch() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        assert!(csr.matvec(&x).is_err());
    }

    #[test]
    fn test_duplicate_coo_entries_summed() {
        let coo = CooMatrix::from_triplets(2, 2, vec![0, 0, 1], vec![0, 0, 1], vec![1.0, 2.0, 5.0])
            .unwrap();
        let csr = coo.to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 3.0);
        assert_eq!(*csr.get(1, 1).unwrap(), 5.0);
        assert_eq!(csr.nnz(), 2);
    }

    #[test]
    fn test_duplicates_across_rows_are_compacted() {
        // Unsorted triplets with duplicates in rows 0 and 2 and a
        // duplicate-free row 1 between them, so compaction must move later
        // rows left and keep every row pointer consistent.
        let coo = CooMatrix::from_triplets(
            3,
            3,
            vec![2, 0, 0, 1, 2, 0, 2],
            vec![1, 2, 2, 0, 1, 0, 0],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
        )
        .unwrap();
        let expected = [6.0, 0.0, 5.0, 4.0, 0.0, 0.0, 7.0, 6.0, 0.0];
        assert_eq!(coo.to_dense().as_slice(), &expected);

        let csr = coo.to_csr();
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.to_dense().as_slice(), &expected);
        // `get` binary-searches, so these also check per-row column order.
        assert_eq!(*csr.get(0, 0).unwrap(), 6.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 5.0);
        assert_eq!(*csr.get(1, 0).unwrap(), 4.0);
        assert_eq!(*csr.get(2, 0).unwrap(), 7.0);
        assert_eq!(*csr.get(2, 1).unwrap(), 6.0);
        assert!(csr.get(1, 1).is_none());

        let csc = coo.to_csc();
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.to_dense().as_slice(), &expected);
        assert_eq!(csc.to_csr().to_dense().as_slice(), &expected);
    }

    #[test]
    fn test_csr_transpose() {
        let csr = sample_coo().to_csr();
        let csc = csr.transpose();
        assert_eq!(csc.nrows(), 3);
        assert_eq!(csc.ncols(), 3);
        let orig = csr.to_dense();
        let trans = csc.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(*trans.get(&[i, j]).unwrap(), *orig.get(&[j, i]).unwrap());
            }
        }
    }

    #[test]
    fn test_csc_transpose() {
        // The sample's transpose has (0, 2) = 4, (2, 0) = 2 and (1, 1) = 3.
        let csr = sample_coo().to_csc().transpose();
        assert_eq!(*csr.get(0, 2).unwrap(), 4.0);
        assert_eq!(*csr.get(2, 0).unwrap(), 2.0);
        assert_eq!(*csr.get(1, 1).unwrap(), 3.0);
        assert!(csr.get(0, 1).is_none());
    }

    #[test]
    fn test_csr_get() {
        let csr = sample_coo().to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 1.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 2.0);
        // Structural zero.
        assert!(csr.get(0, 1).is_none());
        // Out of bounds.
        assert!(csr.get(5, 0).is_none());
    }

    #[test]
    fn test_coo_push() {
        let mut coo = CooMatrix::<f64>::new(2, 2);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 1, 2.0).unwrap();
        assert_eq!(coo.nnz(), 2);
        // Row index out of bounds is rejected.
        assert!(coo.push(2, 0, 1.0).is_err());
    }
}