1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14#[derive(Clone)]
106pub struct CsrArray<T>
107where
108 T: SparseElement + Div<Output = T> + 'static,
109{
110 data: Array1<T>,
112 indices: Array1<usize>,
114 indptr: Array1<usize>,
116 shape: (usize, usize),
118 has_sorted_indices: bool,
120}
121
122impl<T> CsrArray<T>
123where
124 T: SparseElement + Div<Output = T> + Zero + 'static,
125{
126 pub fn new(
140 data: Array1<T>,
141 indices: Array1<usize>,
142 indptr: Array1<usize>,
143 shape: (usize, usize),
144 ) -> SparseResult<Self> {
145 if data.len() != indices.len() {
147 return Err(SparseError::InconsistentData {
148 reason: "data and indices must have the same length".to_string(),
149 });
150 }
151
152 if indptr.len() != shape.0 + 1 {
153 return Err(SparseError::InconsistentData {
154 reason: format!(
155 "indptr length ({}) must be one more than the number of rows ({})",
156 indptr.len(),
157 shape.0
158 ),
159 });
160 }
161
162 if let Some(&max_idx) = indices.iter().max() {
163 if max_idx >= shape.1 {
164 return Err(SparseError::IndexOutOfBounds {
165 index: (0, max_idx),
166 shape,
167 });
168 }
169 }
170
171 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
172 if first != 0 {
173 return Err(SparseError::InconsistentData {
174 reason: "first element of indptr must be 0".to_string(),
175 });
176 }
177
178 if last != data.len() {
179 return Err(SparseError::InconsistentData {
180 reason: format!(
181 "last element of indptr ({}) must equal data length ({})",
182 last,
183 data.len()
184 ),
185 });
186 }
187 }
188
189 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
190
191 Ok(Self {
192 data,
193 indices,
194 indptr,
195 shape,
196 has_sorted_indices,
197 })
198 }
199
200 pub fn from_triplets(
274 rows: &[usize],
275 cols: &[usize],
276 data: &[T],
277 shape: (usize, usize),
278 sorted: bool,
279 ) -> SparseResult<Self> {
280 if rows.len() != cols.len() || rows.len() != data.len() {
281 return Err(SparseError::InconsistentData {
282 reason: "rows, cols, and data must have the same length".to_string(),
283 });
284 }
285
286 if rows.is_empty() {
287 let indptr = Array1::zeros(shape.0 + 1);
289 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
290 }
291
292 let nnz = rows.len();
293 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
294
295 for i in 0..nnz {
296 if rows[i] >= shape.0 || cols[i] >= shape.1 {
297 return Err(SparseError::IndexOutOfBounds {
298 index: (rows[i], cols[i]),
299 shape,
300 });
301 }
302 all_data.push((rows[i], cols[i], data[i]));
303 }
304
305 if !sorted {
306 all_data.sort_by_key(|&(row, col_, _)| (row, col_));
307 }
308
309 let mut row_counts = vec![0; shape.0];
311 for &(row_, _, _) in &all_data {
312 row_counts[row_] += 1;
313 }
314
315 let mut indptr = Vec::with_capacity(shape.0 + 1);
317 indptr.push(0);
318 let mut cumsum = 0;
319 for &count in &row_counts {
320 cumsum += count;
321 indptr.push(cumsum);
322 }
323
324 let mut indices = Vec::with_capacity(nnz);
326 let mut values = Vec::with_capacity(nnz);
327
328 for (_, col, val) in all_data {
329 indices.push(col);
330 values.push(val);
331 }
332
333 Self::new(
334 Array1::from_vec(values),
335 Array1::from_vec(indices),
336 Array1::from_vec(indptr),
337 shape,
338 )
339 }
340
341 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
343 for row in 0..indptr.len() - 1 {
344 let start = indptr[row];
345 let end = indptr[row + 1];
346
347 for i in start..end.saturating_sub(1) {
348 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
349 return false;
350 }
351 }
352 }
353 true
354 }
355
356 pub fn get_data(&self) -> &Array1<T> {
358 &self.data
359 }
360
361 pub fn get_indices(&self) -> &Array1<usize> {
363 &self.indices
364 }
365
366 pub fn get_indptr(&self) -> &Array1<usize> {
368 &self.indptr
369 }
370
371 pub fn nrows(&self) -> usize {
373 self.shape.0
374 }
375
376 pub fn ncols(&self) -> usize {
378 self.shape.1
379 }
380
381 pub fn shape(&self) -> (usize, usize) {
383 self.shape
384 }
385}
386
387impl<T> SparseArray<T> for CsrArray<T>
388where
389 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
390{
391 fn shape(&self) -> (usize, usize) {
392 self.shape
393 }
394
395 fn nnz(&self) -> usize {
396 self.data.len()
397 }
398
399 fn dtype(&self) -> &str {
400 "float" }
402
403 fn to_array(&self) -> Array2<T> {
404 let (rows, cols) = self.shape;
405 let mut result = Array2::zeros((rows, cols));
406
407 for row in 0..rows {
408 let start = self.indptr[row];
409 let end = self.indptr[row + 1];
410
411 for i in start..end {
412 let col = self.indices[i];
413 result[[row, col]] = self.data[i];
414 }
415 }
416
417 result
418 }
419
420 fn toarray(&self) -> Array2<T> {
421 self.to_array()
422 }
423
424 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
425 Ok(Box::new(self.clone()))
428 }
429
430 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
431 Ok(Box::new(self.clone()))
432 }
433
434 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
435 Ok(Box::new(self.clone()))
438 }
439
440 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
441 Ok(Box::new(self.clone()))
444 }
445
446 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
447 Ok(Box::new(self.clone()))
450 }
451
452 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
453 Ok(Box::new(self.clone()))
456 }
457
458 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
459 Ok(Box::new(self.clone()))
462 }
463
464 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
465 let self_array = self.to_array();
468 let other_array = other.to_array();
469
470 if self.shape() != other.shape() {
471 return Err(SparseError::DimensionMismatch {
472 expected: self.shape().0,
473 found: other.shape().0,
474 });
475 }
476
477 let result = &self_array + &other_array;
478
479 let (rows, cols) = self.shape();
481 let mut data = Vec::new();
482 let mut indices = Vec::new();
483 let mut indptr = vec![0];
484
485 for row in 0..rows {
486 for col in 0..cols {
487 let val = result[[row, col]];
488 if val != T::sparse_zero() {
489 data.push(val);
490 indices.push(col);
491 }
492 }
493 indptr.push(data.len());
494 }
495
496 CsrArray::new(
497 Array1::from_vec(data),
498 Array1::from_vec(indices),
499 Array1::from_vec(indptr),
500 self.shape(),
501 )
502 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
503 }
504
505 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
506 let self_array = self.to_array();
508 let other_array = other.to_array();
509
510 if self.shape() != other.shape() {
511 return Err(SparseError::DimensionMismatch {
512 expected: self.shape().0,
513 found: other.shape().0,
514 });
515 }
516
517 let result = &self_array - &other_array;
518
519 let (rows, cols) = self.shape();
521 let mut data = Vec::new();
522 let mut indices = Vec::new();
523 let mut indptr = vec![0];
524
525 for row in 0..rows {
526 for col in 0..cols {
527 let val = result[[row, col]];
528 if val != T::sparse_zero() {
529 data.push(val);
530 indices.push(col);
531 }
532 }
533 indptr.push(data.len());
534 }
535
536 CsrArray::new(
537 Array1::from_vec(data),
538 Array1::from_vec(indices),
539 Array1::from_vec(indptr),
540 self.shape(),
541 )
542 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
543 }
544
545 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
546 let self_array = self.to_array();
549 let other_array = other.to_array();
550
551 if self.shape() != other.shape() {
552 return Err(SparseError::DimensionMismatch {
553 expected: self.shape().0,
554 found: other.shape().0,
555 });
556 }
557
558 let result = &self_array * &other_array;
559
560 let (rows, cols) = self.shape();
562 let mut data = Vec::new();
563 let mut indices = Vec::new();
564 let mut indptr = vec![0];
565
566 for row in 0..rows {
567 for col in 0..cols {
568 let val = result[[row, col]];
569 if val != T::sparse_zero() {
570 data.push(val);
571 indices.push(col);
572 }
573 }
574 indptr.push(data.len());
575 }
576
577 CsrArray::new(
578 Array1::from_vec(data),
579 Array1::from_vec(indices),
580 Array1::from_vec(indptr),
581 self.shape(),
582 )
583 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
584 }
585
586 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
587 let self_array = self.to_array();
589 let other_array = other.to_array();
590
591 if self.shape() != other.shape() {
592 return Err(SparseError::DimensionMismatch {
593 expected: self.shape().0,
594 found: other.shape().0,
595 });
596 }
597
598 let result = &self_array / &other_array;
599
600 let (rows, cols) = self.shape();
602 let mut data = Vec::new();
603 let mut indices = Vec::new();
604 let mut indptr = vec![0];
605
606 for row in 0..rows {
607 for col in 0..cols {
608 let val = result[[row, col]];
609 if val != T::sparse_zero() {
610 data.push(val);
611 indices.push(col);
612 }
613 }
614 indptr.push(data.len());
615 }
616
617 CsrArray::new(
618 Array1::from_vec(data),
619 Array1::from_vec(indices),
620 Array1::from_vec(indptr),
621 self.shape(),
622 )
623 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
624 }
625
626 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
627 let (m, n) = self.shape();
630 let (p, q) = other.shape();
631
632 if n != p {
633 return Err(SparseError::DimensionMismatch {
634 expected: n,
635 found: p,
636 });
637 }
638
639 let mut result = Array2::zeros((m, q));
640 let other_array = other.to_array();
641
642 for row in 0..m {
643 let start = self.indptr[row];
644 let end = self.indptr[row + 1];
645
646 for j in 0..q {
647 let mut sum = T::sparse_zero();
648 for idx in start..end {
649 let col = self.indices[idx];
650 sum = sum + self.data[idx] * other_array[[col, j]];
651 }
652 if sum != T::sparse_zero() {
653 result[[row, j]] = sum;
654 }
655 }
656 }
657
658 let mut data = Vec::new();
660 let mut indices = Vec::new();
661 let mut indptr = vec![0];
662
663 for row in 0..m {
664 for col in 0..q {
665 let val = result[[row, col]];
666 if val != T::sparse_zero() {
667 data.push(val);
668 indices.push(col);
669 }
670 }
671 indptr.push(data.len());
672 }
673
674 CsrArray::new(
675 Array1::from_vec(data),
676 Array1::from_vec(indices),
677 Array1::from_vec(indptr),
678 (m, q),
679 )
680 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
681 }
682
683 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
684 let (m, n) = self.shape();
685 if n != other.len() {
686 return Err(SparseError::DimensionMismatch {
687 expected: n,
688 found: other.len(),
689 });
690 }
691
692 let mut result = Array1::zeros(m);
693
694 for row in 0..m {
695 let start = self.indptr[row];
696 let end = self.indptr[row + 1];
697
698 let mut sum = T::sparse_zero();
699 for idx in start..end {
700 let col = self.indices[idx];
701 sum = sum + self.data[idx] * other[col];
702 }
703 result[row] = sum;
704 }
705
706 Ok(result)
707 }
708
709 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
710 let (rows, cols) = self.shape();
714 let mut row_indices = Vec::with_capacity(self.nnz());
715 let mut col_indices = Vec::with_capacity(self.nnz());
716 let mut values = Vec::with_capacity(self.nnz());
717
718 for row in 0..rows {
719 let start = self.indptr[row];
720 let end = self.indptr[row + 1];
721
722 for idx in start..end {
723 let col = self.indices[idx];
724 row_indices.push(col); col_indices.push(row);
726 values.push(self.data[idx]);
727 }
728 }
729
730 CsrArray::from_triplets(
732 &row_indices,
733 &col_indices,
734 &values,
735 (cols, rows), false,
737 )
738 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
739 }
740
741 fn copy(&self) -> Box<dyn SparseArray<T>> {
742 Box::new(self.clone())
743 }
744
745 fn get(&self, i: usize, j: usize) -> T {
746 if i >= self.shape.0 || j >= self.shape.1 {
747 return T::sparse_zero();
748 }
749
750 let start = self.indptr[i];
751 let end = self.indptr[i + 1];
752
753 for idx in start..end {
754 if self.indices[idx] == j {
755 return self.data[idx];
756 }
757 if self.has_sorted_indices && self.indices[idx] > j {
759 break;
760 }
761 }
762
763 T::sparse_zero()
764 }
765
766 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
767 if i >= self.shape.0 || j >= self.shape.1 {
770 return Err(SparseError::IndexOutOfBounds {
771 index: (i, j),
772 shape: self.shape,
773 });
774 }
775
776 let start = self.indptr[i];
777 let end = self.indptr[i + 1];
778
779 for idx in start..end {
781 if self.indices[idx] == j {
782 self.data[idx] = value;
784 return Ok(());
785 }
786 if self.has_sorted_indices && self.indices[idx] > j {
788 return Err(SparseError::NotImplemented(
791 "Inserting new elements in CSR format".to_string(),
792 ));
793 }
794 }
795
796 Err(SparseError::NotImplemented(
799 "Inserting new elements in CSR format".to_string(),
800 ))
801 }
802
803 fn eliminate_zeros(&mut self) {
804 let mut new_data = Vec::new();
806 let mut new_indices = Vec::new();
807 let mut new_indptr = vec![0];
808
809 let (rows, _) = self.shape();
810
811 for row in 0..rows {
812 let start = self.indptr[row];
813 let end = self.indptr[row + 1];
814
815 for idx in start..end {
816 if !SparseElement::is_zero(&self.data[idx]) {
817 new_data.push(self.data[idx]);
818 new_indices.push(self.indices[idx]);
819 }
820 }
821 new_indptr.push(new_data.len());
822 }
823
824 self.data = Array1::from_vec(new_data);
826 self.indices = Array1::from_vec(new_indices);
827 self.indptr = Array1::from_vec(new_indptr);
828 }
829
830 fn sort_indices(&mut self) {
831 if self.has_sorted_indices {
832 return;
833 }
834
835 let (rows, _) = self.shape();
836
837 for row in 0..rows {
838 let start = self.indptr[row];
839 let end = self.indptr[row + 1];
840
841 if start == end {
842 continue;
843 }
844
845 let mut row_data = Vec::with_capacity(end - start);
847 for idx in start..end {
848 row_data.push((self.indices[idx], self.data[idx]));
849 }
850
851 row_data.sort_by_key(|&(col_, _)| col_);
853
854 for (i, (col, val)) in row_data.into_iter().enumerate() {
856 self.indices[start + i] = col;
857 self.data[start + i] = val;
858 }
859 }
860
861 self.has_sorted_indices = true;
862 }
863
864 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
865 if self.has_sorted_indices {
866 return Box::new(self.clone());
867 }
868
869 let mut sorted = self.clone();
870 sorted.sort_indices();
871 Box::new(sorted)
872 }
873
874 fn has_sorted_indices(&self) -> bool {
875 self.has_sorted_indices
876 }
877
878 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
879 match axis {
880 None => {
881 let mut sum = T::sparse_zero();
883 for &val in self.data.iter() {
884 sum = sum + val;
885 }
886 Ok(SparseSum::Scalar(sum))
887 }
888 Some(0) => {
889 let (_, cols) = self.shape();
891 let mut result = vec![T::sparse_zero(); cols];
892
893 for row in 0..self.shape.0 {
894 let start = self.indptr[row];
895 let end = self.indptr[row + 1];
896
897 for idx in start..end {
898 let col = self.indices[idx];
899 result[col] = result[col] + self.data[idx];
900 }
901 }
902
903 let mut data = Vec::new();
905 let mut indices = Vec::new();
906 let mut indptr = vec![0];
907
908 for (col, &val) in result.iter().enumerate() {
909 if val != T::sparse_zero() {
910 data.push(val);
911 indices.push(col);
912 }
913 }
914 indptr.push(data.len());
915
916 let result_array = CsrArray::new(
917 Array1::from_vec(data),
918 Array1::from_vec(indices),
919 Array1::from_vec(indptr),
920 (1, cols),
921 )?;
922
923 Ok(SparseSum::SparseArray(Box::new(result_array)))
924 }
925 Some(1) => {
926 let mut result = Vec::with_capacity(self.shape.0);
928
929 for row in 0..self.shape.0 {
930 let start = self.indptr[row];
931 let end = self.indptr[row + 1];
932
933 let mut row_sum = T::sparse_zero();
934 for idx in start..end {
935 row_sum = row_sum + self.data[idx];
936 }
937 result.push(row_sum);
938 }
939
940 let mut data = Vec::new();
942 let mut indices = Vec::new();
943 let mut indptr = vec![0];
944
945 for &val in result.iter() {
946 if val != T::sparse_zero() {
947 data.push(val);
948 indices.push(0);
949 indptr.push(data.len());
950 } else {
951 indptr.push(data.len());
952 }
953 }
954
955 let result_array = CsrArray::new(
956 Array1::from_vec(data),
957 Array1::from_vec(indices),
958 Array1::from_vec(indptr),
959 (self.shape.0, 1),
960 )?;
961
962 Ok(SparseSum::SparseArray(Box::new(result_array)))
963 }
964 _ => Err(SparseError::InvalidAxis),
965 }
966 }
967
968 fn max(&self) -> T {
969 if self.data.is_empty() {
970 return T::sparse_zero();
972 }
973
974 let mut max_val = self.data[0];
975 for &val in self.data.iter().skip(1) {
976 if val > max_val {
977 max_val = val;
978 }
979 }
980
981 let zero = T::sparse_zero();
984 if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
985 max_val = zero;
986 }
987
988 max_val
989 }
990
991 fn min(&self) -> T {
992 if self.data.is_empty() {
993 return T::sparse_zero();
995 }
996
997 let mut min_val = self.data[0];
998 for &val in self.data.iter().skip(1) {
999 if val < min_val {
1000 min_val = val;
1001 }
1002 }
1003
1004 let zero = T::sparse_zero();
1007 if min_val > zero && self.nnz() < self.shape.0 * self.shape.1 {
1008 min_val = zero;
1009 }
1010
1011 min_val
1012 }
1013
1014 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1015 let nnz = self.nnz();
1016 let mut rows = Vec::with_capacity(nnz);
1017 let mut cols = Vec::with_capacity(nnz);
1018 let mut values = Vec::with_capacity(nnz);
1019
1020 for row in 0..self.shape.0 {
1021 let start = self.indptr[row];
1022 let end = self.indptr[row + 1];
1023
1024 for idx in start..end {
1025 let col = self.indices[idx];
1026 rows.push(row);
1027 cols.push(col);
1028 values.push(self.data[idx]);
1029 }
1030 }
1031
1032 (
1033 Array1::from_vec(rows),
1034 Array1::from_vec(cols),
1035 Array1::from_vec(values),
1036 )
1037 }
1038
1039 fn slice(
1040 &self,
1041 row_range: (usize, usize),
1042 col_range: (usize, usize),
1043 ) -> SparseResult<Box<dyn SparseArray<T>>> {
1044 let (start_row, end_row) = row_range;
1045 let (start_col, end_col) = col_range;
1046
1047 if start_row >= self.shape.0
1048 || end_row > self.shape.0
1049 || start_col >= self.shape.1
1050 || end_col > self.shape.1
1051 {
1052 return Err(SparseError::InvalidSliceRange);
1053 }
1054
1055 if start_row >= end_row || start_col >= end_col {
1056 return Err(SparseError::InvalidSliceRange);
1057 }
1058
1059 let mut data = Vec::new();
1060 let mut indices = Vec::new();
1061 let mut indptr = vec![0];
1062
1063 for row in start_row..end_row {
1064 let start = self.indptr[row];
1065 let end = self.indptr[row + 1];
1066
1067 for idx in start..end {
1068 let col = self.indices[idx];
1069 if col >= start_col && col < end_col {
1070 data.push(self.data[idx]);
1071 indices.push(col - start_col);
1072 }
1073 }
1074 indptr.push(data.len());
1075 }
1076
1077 CsrArray::new(
1078 Array1::from_vec(data),
1079 Array1::from_vec(indices),
1080 Array1::from_vec(indptr),
1081 (end_row - start_row, end_col - start_col),
1082 )
1083 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1084 }
1085
1086 fn as_any(&self) -> &dyn std::any::Any {
1087 self
1088 }
1089
1090 fn get_indptr(&self) -> Option<&Array1<usize>> {
1091 Some(&self.indptr)
1092 }
1093
1094 fn indptr(&self) -> Option<&Array1<usize>> {
1095 Some(&self.indptr)
1096 }
1097}
1098
1099impl<T> fmt::Debug for CsrArray<T>
1100where
1101 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
1102{
1103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1104 write!(
1105 f,
1106 "CsrArray<{}x{}, nnz={}>",
1107 self.shape.0,
1108 self.shape.1,
1109 self.nnz()
1110 )
1111 }
1112}
1113
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn fixture() -> CsrArray<f64> {
        CsrArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 2, 1, 0, 2]),
            Array1::from_vec(vec![0, 2, 3, 5]),
            (3, 3),
        )
        .unwrap()
    }

    // Checks shape, stored count and a sample of cells (including one implicit zero).
    fn check_fixture_entries(csr: &CsrArray<f64>) {
        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);

        let expected = [
            ((0, 0), 1.0),
            ((0, 2), 2.0),
            ((1, 1), 3.0),
            ((2, 0), 4.0),
            ((2, 2), 5.0),
            ((0, 1), 0.0),
        ];
        for &((i, j), value) in expected.iter() {
            assert_eq!(csr.get(i, j), value);
        }
    }

    #[test]
    fn test_csr_array_construction() {
        check_fixture_entries(&fixture());
    }

    #[test]
    fn test_csr_from_triplets() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();
        check_fixture_entries(&csr);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = fixture().to_array();
        assert_eq!(dense.shape(), &[3, 3]);

        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(dense[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let product = fixture().dot_vector(&rhs.view()).unwrap();

        assert_eq!(product.len(), 3);
        // [1*1 + 2*3, 3*2, 4*1 + 5*3]
        assert_relative_eq!(product[0], 7.0);
        assert_relative_eq!(product[1], 6.0);
        assert_relative_eq!(product[2], 19.0);
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = fixture();

        match csr.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_relative_eq!(total, 15.0),
            _ => panic!("Expected scalar sum"),
        }

        match csr.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(axis0) => {
                let axis0 = axis0.to_array();
                assert_eq!(axis0.shape(), &[1, 3]);
                assert_relative_eq!(axis0[[0, 0]], 5.0);
                assert_relative_eq!(axis0[[0, 1]], 3.0);
                assert_relative_eq!(axis0[[0, 2]], 7.0);
            }
            _ => panic!("Expected sparse array sum"),
        }

        match csr.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(axis1) => {
                let axis1 = axis1.to_array();
                assert_eq!(axis1.shape(), &[3, 1]);
                assert_relative_eq!(axis1[[0, 0]], 3.0);
                assert_relative_eq!(axis1[[1, 0]], 3.0);
                assert_relative_eq!(axis1[[2, 0]], 9.0);
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}