1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14#[derive(Clone)]
106pub struct CsrArray<T>
107where
108 T: Float
109 + Add<Output = T>
110 + Sub<Output = T>
111 + Mul<Output = T>
112 + Div<Output = T>
113 + Debug
114 + Copy
115 + 'static,
116{
117 data: Array1<T>,
119 indices: Array1<usize>,
121 indptr: Array1<usize>,
123 shape: (usize, usize),
125 has_sorted_indices: bool,
127}
128
129impl<T> CsrArray<T>
130where
131 T: Float
132 + Add<Output = T>
133 + Sub<Output = T>
134 + Mul<Output = T>
135 + Div<Output = T>
136 + Debug
137 + Copy
138 + 'static,
139{
140 pub fn new(
154 data: Array1<T>,
155 indices: Array1<usize>,
156 indptr: Array1<usize>,
157 shape: (usize, usize),
158 ) -> SparseResult<Self> {
159 if data.len() != indices.len() {
161 return Err(SparseError::InconsistentData {
162 reason: "data and indices must have the same length".to_string(),
163 });
164 }
165
166 if indptr.len() != shape.0 + 1 {
167 return Err(SparseError::InconsistentData {
168 reason: format!(
169 "indptr length ({}) must be one more than the number of rows ({})",
170 indptr.len(),
171 shape.0
172 ),
173 });
174 }
175
176 if let Some(&max_idx) = indices.iter().max() {
177 if max_idx >= shape.1 {
178 return Err(SparseError::IndexOutOfBounds {
179 index: (0, max_idx),
180 shape,
181 });
182 }
183 }
184
185 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
186 if first != 0 {
187 return Err(SparseError::InconsistentData {
188 reason: "first element of indptr must be 0".to_string(),
189 });
190 }
191
192 if last != data.len() {
193 return Err(SparseError::InconsistentData {
194 reason: format!(
195 "last element of indptr ({}) must equal data length ({})",
196 last,
197 data.len()
198 ),
199 });
200 }
201 }
202
203 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
204
205 Ok(Self {
206 data,
207 indices,
208 indptr,
209 shape,
210 has_sorted_indices,
211 })
212 }
213
214 pub fn from_triplets(
288 rows: &[usize],
289 cols: &[usize],
290 data: &[T],
291 shape: (usize, usize),
292 sorted: bool,
293 ) -> SparseResult<Self> {
294 if rows.len() != cols.len() || rows.len() != data.len() {
295 return Err(SparseError::InconsistentData {
296 reason: "rows, cols, and data must have the same length".to_string(),
297 });
298 }
299
300 if rows.is_empty() {
301 let indptr = Array1::zeros(shape.0 + 1);
303 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
304 }
305
306 let nnz = rows.len();
307 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
308
309 for i in 0..nnz {
310 if rows[i] >= shape.0 || cols[i] >= shape.1 {
311 return Err(SparseError::IndexOutOfBounds {
312 index: (rows[i], cols[i]),
313 shape,
314 });
315 }
316 all_data.push((rows[i], cols[i], data[i]));
317 }
318
319 if !sorted {
320 all_data.sort_by_key(|&(row, col_, _)| (row, col_));
321 }
322
323 let mut row_counts = vec![0; shape.0];
325 for &(row_, _, _) in &all_data {
326 row_counts[row_] += 1;
327 }
328
329 let mut indptr = Vec::with_capacity(shape.0 + 1);
331 indptr.push(0);
332 let mut cumsum = 0;
333 for &count in &row_counts {
334 cumsum += count;
335 indptr.push(cumsum);
336 }
337
338 let mut indices = Vec::with_capacity(nnz);
340 let mut values = Vec::with_capacity(nnz);
341
342 for (_, col, val) in all_data {
343 indices.push(col);
344 values.push(val);
345 }
346
347 Self::new(
348 Array1::from_vec(values),
349 Array1::from_vec(indices),
350 Array1::from_vec(indptr),
351 shape,
352 )
353 }
354
355 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
357 for row in 0..indptr.len() - 1 {
358 let start = indptr[row];
359 let end = indptr[row + 1];
360
361 for i in start..end.saturating_sub(1) {
362 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
363 return false;
364 }
365 }
366 }
367 true
368 }
369
370 pub fn get_data(&self) -> &Array1<T> {
372 &self.data
373 }
374
375 pub fn get_indices(&self) -> &Array1<usize> {
377 &self.indices
378 }
379
380 pub fn get_indptr(&self) -> &Array1<usize> {
382 &self.indptr
383 }
384
385 pub fn nrows(&self) -> usize {
387 self.shape.0
388 }
389
390 pub fn ncols(&self) -> usize {
392 self.shape.1
393 }
394
395 pub fn shape(&self) -> (usize, usize) {
397 self.shape
398 }
399}
400
401impl<T> SparseArray<T> for CsrArray<T>
402where
403 T: Float
404 + Add<Output = T>
405 + Sub<Output = T>
406 + Mul<Output = T>
407 + Div<Output = T>
408 + Debug
409 + Copy
410 + 'static,
411{
412 fn shape(&self) -> (usize, usize) {
413 self.shape
414 }
415
416 fn nnz(&self) -> usize {
417 self.data.len()
418 }
419
420 fn dtype(&self) -> &str {
421 "float" }
423
424 fn to_array(&self) -> Array2<T> {
425 let (rows, cols) = self.shape;
426 let mut result = Array2::zeros((rows, cols));
427
428 for row in 0..rows {
429 let start = self.indptr[row];
430 let end = self.indptr[row + 1];
431
432 for i in start..end {
433 let col = self.indices[i];
434 result[[row, col]] = self.data[i];
435 }
436 }
437
438 result
439 }
440
441 fn toarray(&self) -> Array2<T> {
442 self.to_array()
443 }
444
445 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446 Ok(Box::new(self.clone()))
449 }
450
451 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
452 Ok(Box::new(self.clone()))
453 }
454
455 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456 Ok(Box::new(self.clone()))
459 }
460
461 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462 Ok(Box::new(self.clone()))
465 }
466
467 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468 Ok(Box::new(self.clone()))
471 }
472
473 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
474 Ok(Box::new(self.clone()))
477 }
478
479 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
480 Ok(Box::new(self.clone()))
483 }
484
485 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
486 let self_array = self.to_array();
489 let other_array = other.to_array();
490
491 if self.shape() != other.shape() {
492 return Err(SparseError::DimensionMismatch {
493 expected: self.shape().0,
494 found: other.shape().0,
495 });
496 }
497
498 let result = &self_array + &other_array;
499
500 let (rows, cols) = self.shape();
502 let mut data = Vec::new();
503 let mut indices = Vec::new();
504 let mut indptr = vec![0];
505
506 for row in 0..rows {
507 for col in 0..cols {
508 let val = result[[row, col]];
509 if !val.is_zero() {
510 data.push(val);
511 indices.push(col);
512 }
513 }
514 indptr.push(data.len());
515 }
516
517 CsrArray::new(
518 Array1::from_vec(data),
519 Array1::from_vec(indices),
520 Array1::from_vec(indptr),
521 self.shape(),
522 )
523 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
524 }
525
526 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
527 let self_array = self.to_array();
529 let other_array = other.to_array();
530
531 if self.shape() != other.shape() {
532 return Err(SparseError::DimensionMismatch {
533 expected: self.shape().0,
534 found: other.shape().0,
535 });
536 }
537
538 let result = &self_array - &other_array;
539
540 let (rows, cols) = self.shape();
542 let mut data = Vec::new();
543 let mut indices = Vec::new();
544 let mut indptr = vec![0];
545
546 for row in 0..rows {
547 for col in 0..cols {
548 let val = result[[row, col]];
549 if !val.is_zero() {
550 data.push(val);
551 indices.push(col);
552 }
553 }
554 indptr.push(data.len());
555 }
556
557 CsrArray::new(
558 Array1::from_vec(data),
559 Array1::from_vec(indices),
560 Array1::from_vec(indptr),
561 self.shape(),
562 )
563 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
564 }
565
566 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
567 let self_array = self.to_array();
570 let other_array = other.to_array();
571
572 if self.shape() != other.shape() {
573 return Err(SparseError::DimensionMismatch {
574 expected: self.shape().0,
575 found: other.shape().0,
576 });
577 }
578
579 let result = &self_array * &other_array;
580
581 let (rows, cols) = self.shape();
583 let mut data = Vec::new();
584 let mut indices = Vec::new();
585 let mut indptr = vec![0];
586
587 for row in 0..rows {
588 for col in 0..cols {
589 let val = result[[row, col]];
590 if !val.is_zero() {
591 data.push(val);
592 indices.push(col);
593 }
594 }
595 indptr.push(data.len());
596 }
597
598 CsrArray::new(
599 Array1::from_vec(data),
600 Array1::from_vec(indices),
601 Array1::from_vec(indptr),
602 self.shape(),
603 )
604 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
605 }
606
607 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
608 let self_array = self.to_array();
610 let other_array = other.to_array();
611
612 if self.shape() != other.shape() {
613 return Err(SparseError::DimensionMismatch {
614 expected: self.shape().0,
615 found: other.shape().0,
616 });
617 }
618
619 let result = &self_array / &other_array;
620
621 let (rows, cols) = self.shape();
623 let mut data = Vec::new();
624 let mut indices = Vec::new();
625 let mut indptr = vec![0];
626
627 for row in 0..rows {
628 for col in 0..cols {
629 let val = result[[row, col]];
630 if !val.is_zero() {
631 data.push(val);
632 indices.push(col);
633 }
634 }
635 indptr.push(data.len());
636 }
637
638 CsrArray::new(
639 Array1::from_vec(data),
640 Array1::from_vec(indices),
641 Array1::from_vec(indptr),
642 self.shape(),
643 )
644 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
645 }
646
647 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
648 let (m, n) = self.shape();
651 let (p, q) = other.shape();
652
653 if n != p {
654 return Err(SparseError::DimensionMismatch {
655 expected: n,
656 found: p,
657 });
658 }
659
660 let mut result = Array2::zeros((m, q));
661 let other_array = other.to_array();
662
663 for row in 0..m {
664 let start = self.indptr[row];
665 let end = self.indptr[row + 1];
666
667 for j in 0..q {
668 let mut sum = T::zero();
669 for idx in start..end {
670 let col = self.indices[idx];
671 sum = sum + self.data[idx] * other_array[[col, j]];
672 }
673 if !sum.is_zero() {
674 result[[row, j]] = sum;
675 }
676 }
677 }
678
679 let mut data = Vec::new();
681 let mut indices = Vec::new();
682 let mut indptr = vec![0];
683
684 for row in 0..m {
685 for col in 0..q {
686 let val = result[[row, col]];
687 if !val.is_zero() {
688 data.push(val);
689 indices.push(col);
690 }
691 }
692 indptr.push(data.len());
693 }
694
695 CsrArray::new(
696 Array1::from_vec(data),
697 Array1::from_vec(indices),
698 Array1::from_vec(indptr),
699 (m, q),
700 )
701 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
702 }
703
704 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
705 let (m, n) = self.shape();
706 if n != other.len() {
707 return Err(SparseError::DimensionMismatch {
708 expected: n,
709 found: other.len(),
710 });
711 }
712
713 let mut result = Array1::zeros(m);
714
715 for row in 0..m {
716 let start = self.indptr[row];
717 let end = self.indptr[row + 1];
718
719 let mut sum = T::zero();
720 for idx in start..end {
721 let col = self.indices[idx];
722 sum = sum + self.data[idx] * other[col];
723 }
724 result[row] = sum;
725 }
726
727 Ok(result)
728 }
729
730 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
731 let (rows, cols) = self.shape();
735 let mut row_indices = Vec::with_capacity(self.nnz());
736 let mut col_indices = Vec::with_capacity(self.nnz());
737 let mut values = Vec::with_capacity(self.nnz());
738
739 for row in 0..rows {
740 let start = self.indptr[row];
741 let end = self.indptr[row + 1];
742
743 for idx in start..end {
744 let col = self.indices[idx];
745 row_indices.push(col); col_indices.push(row);
747 values.push(self.data[idx]);
748 }
749 }
750
751 CsrArray::from_triplets(
753 &row_indices,
754 &col_indices,
755 &values,
756 (cols, rows), false,
758 )
759 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
760 }
761
762 fn copy(&self) -> Box<dyn SparseArray<T>> {
763 Box::new(self.clone())
764 }
765
766 fn get(&self, i: usize, j: usize) -> T {
767 if i >= self.shape.0 || j >= self.shape.1 {
768 return T::zero();
769 }
770
771 let start = self.indptr[i];
772 let end = self.indptr[i + 1];
773
774 for idx in start..end {
775 if self.indices[idx] == j {
776 return self.data[idx];
777 }
778 if self.has_sorted_indices && self.indices[idx] > j {
780 break;
781 }
782 }
783
784 T::zero()
785 }
786
787 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
788 if i >= self.shape.0 || j >= self.shape.1 {
791 return Err(SparseError::IndexOutOfBounds {
792 index: (i, j),
793 shape: self.shape,
794 });
795 }
796
797 let start = self.indptr[i];
798 let end = self.indptr[i + 1];
799
800 for idx in start..end {
802 if self.indices[idx] == j {
803 self.data[idx] = value;
805 return Ok(());
806 }
807 if self.has_sorted_indices && self.indices[idx] > j {
809 return Err(SparseError::NotImplemented(
812 "Inserting new elements in CSR format".to_string(),
813 ));
814 }
815 }
816
817 Err(SparseError::NotImplemented(
820 "Inserting new elements in CSR format".to_string(),
821 ))
822 }
823
824 fn eliminate_zeros(&mut self) {
825 let mut new_data = Vec::new();
827 let mut new_indices = Vec::new();
828 let mut new_indptr = vec![0];
829
830 let (rows, _) = self.shape();
831
832 for row in 0..rows {
833 let start = self.indptr[row];
834 let end = self.indptr[row + 1];
835
836 for idx in start..end {
837 if !self.data[idx].is_zero() {
838 new_data.push(self.data[idx]);
839 new_indices.push(self.indices[idx]);
840 }
841 }
842 new_indptr.push(new_data.len());
843 }
844
845 self.data = Array1::from_vec(new_data);
847 self.indices = Array1::from_vec(new_indices);
848 self.indptr = Array1::from_vec(new_indptr);
849 }
850
851 fn sort_indices(&mut self) {
852 if self.has_sorted_indices {
853 return;
854 }
855
856 let (rows, _) = self.shape();
857
858 for row in 0..rows {
859 let start = self.indptr[row];
860 let end = self.indptr[row + 1];
861
862 if start == end {
863 continue;
864 }
865
866 let mut row_data = Vec::with_capacity(end - start);
868 for idx in start..end {
869 row_data.push((self.indices[idx], self.data[idx]));
870 }
871
872 row_data.sort_by_key(|&(col_, _)| col_);
874
875 for (i, (col, val)) in row_data.into_iter().enumerate() {
877 self.indices[start + i] = col;
878 self.data[start + i] = val;
879 }
880 }
881
882 self.has_sorted_indices = true;
883 }
884
885 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
886 if self.has_sorted_indices {
887 return Box::new(self.clone());
888 }
889
890 let mut sorted = self.clone();
891 sorted.sort_indices();
892 Box::new(sorted)
893 }
894
895 fn has_sorted_indices(&self) -> bool {
896 self.has_sorted_indices
897 }
898
899 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
900 match axis {
901 None => {
902 let mut sum = T::zero();
904 for &val in self.data.iter() {
905 sum = sum + val;
906 }
907 Ok(SparseSum::Scalar(sum))
908 }
909 Some(0) => {
910 let (_, cols) = self.shape();
912 let mut result = vec![T::zero(); cols];
913
914 for row in 0..self.shape.0 {
915 let start = self.indptr[row];
916 let end = self.indptr[row + 1];
917
918 for idx in start..end {
919 let col = self.indices[idx];
920 result[col] = result[col] + self.data[idx];
921 }
922 }
923
924 let mut data = Vec::new();
926 let mut indices = Vec::new();
927 let mut indptr = vec![0];
928
929 for (col, &val) in result.iter().enumerate() {
930 if !val.is_zero() {
931 data.push(val);
932 indices.push(col);
933 }
934 }
935 indptr.push(data.len());
936
937 let result_array = CsrArray::new(
938 Array1::from_vec(data),
939 Array1::from_vec(indices),
940 Array1::from_vec(indptr),
941 (1, cols),
942 )?;
943
944 Ok(SparseSum::SparseArray(Box::new(result_array)))
945 }
946 Some(1) => {
947 let mut result = Vec::with_capacity(self.shape.0);
949
950 for row in 0..self.shape.0 {
951 let start = self.indptr[row];
952 let end = self.indptr[row + 1];
953
954 let mut row_sum = T::zero();
955 for idx in start..end {
956 row_sum = row_sum + self.data[idx];
957 }
958 result.push(row_sum);
959 }
960
961 let mut data = Vec::new();
963 let mut indices = Vec::new();
964 let mut indptr = vec![0];
965
966 for &val in result.iter() {
967 if !val.is_zero() {
968 data.push(val);
969 indices.push(0);
970 indptr.push(data.len());
971 } else {
972 indptr.push(data.len());
973 }
974 }
975
976 let result_array = CsrArray::new(
977 Array1::from_vec(data),
978 Array1::from_vec(indices),
979 Array1::from_vec(indptr),
980 (self.shape.0, 1),
981 )?;
982
983 Ok(SparseSum::SparseArray(Box::new(result_array)))
984 }
985 _ => Err(SparseError::InvalidAxis),
986 }
987 }
988
989 fn max(&self) -> T {
990 if self.data.is_empty() {
991 return T::neg_infinity();
992 }
993
994 let mut max_val = self.data[0];
995 for &val in self.data.iter().skip(1) {
996 if val > max_val {
997 max_val = val;
998 }
999 }
1000
1001 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1003 max_val = T::zero();
1004 }
1005
1006 max_val
1007 }
1008
1009 fn min(&self) -> T {
1010 if self.data.is_empty() {
1011 return T::infinity();
1012 }
1013
1014 let mut min_val = self.data[0];
1015 for &val in self.data.iter().skip(1) {
1016 if val < min_val {
1017 min_val = val;
1018 }
1019 }
1020
1021 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1023 min_val = T::zero();
1024 }
1025
1026 min_val
1027 }
1028
1029 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1030 let nnz = self.nnz();
1031 let mut rows = Vec::with_capacity(nnz);
1032 let mut cols = Vec::with_capacity(nnz);
1033 let mut values = Vec::with_capacity(nnz);
1034
1035 for row in 0..self.shape.0 {
1036 let start = self.indptr[row];
1037 let end = self.indptr[row + 1];
1038
1039 for idx in start..end {
1040 let col = self.indices[idx];
1041 rows.push(row);
1042 cols.push(col);
1043 values.push(self.data[idx]);
1044 }
1045 }
1046
1047 (
1048 Array1::from_vec(rows),
1049 Array1::from_vec(cols),
1050 Array1::from_vec(values),
1051 )
1052 }
1053
1054 fn slice(
1055 &self,
1056 row_range: (usize, usize),
1057 col_range: (usize, usize),
1058 ) -> SparseResult<Box<dyn SparseArray<T>>> {
1059 let (start_row, end_row) = row_range;
1060 let (start_col, end_col) = col_range;
1061
1062 if start_row >= self.shape.0
1063 || end_row > self.shape.0
1064 || start_col >= self.shape.1
1065 || end_col > self.shape.1
1066 {
1067 return Err(SparseError::InvalidSliceRange);
1068 }
1069
1070 if start_row >= end_row || start_col >= end_col {
1071 return Err(SparseError::InvalidSliceRange);
1072 }
1073
1074 let mut data = Vec::new();
1075 let mut indices = Vec::new();
1076 let mut indptr = vec![0];
1077
1078 for row in start_row..end_row {
1079 let start = self.indptr[row];
1080 let end = self.indptr[row + 1];
1081
1082 for idx in start..end {
1083 let col = self.indices[idx];
1084 if col >= start_col && col < end_col {
1085 data.push(self.data[idx]);
1086 indices.push(col - start_col);
1087 }
1088 }
1089 indptr.push(data.len());
1090 }
1091
1092 CsrArray::new(
1093 Array1::from_vec(data),
1094 Array1::from_vec(indices),
1095 Array1::from_vec(indptr),
1096 (end_row - start_row, end_col - start_col),
1097 )
1098 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1099 }
1100
1101 fn as_any(&self) -> &dyn std::any::Any {
1102 self
1103 }
1104
1105 fn get_indptr(&self) -> Option<&Array1<usize>> {
1106 Some(&self.indptr)
1107 }
1108
1109 fn indptr(&self) -> Option<&Array1<usize>> {
1110 Some(&self.indptr)
1111 }
1112}
1113
1114impl<T> fmt::Debug for CsrArray<T>
1115where
1116 T: Float
1117 + Add<Output = T>
1118 + Sub<Output = T>
1119 + Mul<Output = T>
1120 + Div<Output = T>
1121 + Debug
1122 + Copy
1123 + 'static,
1124{
1125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1126 write!(
1127 f,
1128 "CsrArray<{}x{}, nnz={}>",
1129 self.shape.0,
1130 self.shape.1,
1131 self.nnz()
1132 )
1133 }
1134}
1135
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3x3 fixture shared by the tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    fn sample_csr() -> CsrArray<f64> {
        CsrArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 2, 1, 0, 2]),
            Array1::from_vec(vec![0, 2, 3, 5]),
            (3, 3),
        )
        .unwrap()
    }

    /// Checks every stored entry of the fixture plus one implicit zero via `get`.
    fn assert_sample_entries(csr: &CsrArray<f64>) {
        let expected = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
            (0, 1, 0.0),
        ];
        for &(row, col, value) in expected.iter() {
            assert_eq!(csr.get(row, col), value);
        }
    }

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_sample_entries(&csr);
    }

    #[test]
    fn test_csr_from_triplets() {
        let row_ids = vec![0, 0, 1, 2, 2];
        let col_ids = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_ids, &col_ids, &values, (3, 3), false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_sample_entries(&csr);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();
        assert_eq!(dense.shape(), &[3, 3]);

        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (row, row_values) in expected.iter().enumerate() {
            for (col, &value) in row_values.iter().enumerate() {
                assert_eq!(dense[[row, col]], value);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let product = csr.dot_vector(&x.view()).unwrap();

        assert_eq!(product.len(), 3);
        // Row-wise: 1*1 + 2*3 = 7, 3*2 = 6, 4*1 + 5*3 = 19.
        let expected = [7.0, 6.0, 19.0];
        for (&actual, &wanted) in product.iter().zip(expected.iter()) {
            assert_relative_eq!(actual, wanted);
        }
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        match csr.sum(None).unwrap() {
            SparseSum::Scalar(total) => {
                assert_relative_eq!(total, 15.0);
            }
            _ => panic!("Expected scalar sum"),
        }

        match csr.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(column_totals) => {
                let dense = column_totals.to_array();
                assert_eq!(dense.shape(), &[1, 3]);
                for (col, &wanted) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(dense[[0, col]], wanted);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        match csr.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(row_totals) => {
                let dense = row_totals.to_array();
                assert_eq!(dense.shape(), &[3, 1]);
                for (row, &wanted) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(dense[[row, 0]], wanted);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}