1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14#[derive(Clone)]
103pub struct CsrArray<T>
104where
105 T: Float
106 + Add<Output = T>
107 + Sub<Output = T>
108 + Mul<Output = T>
109 + Div<Output = T>
110 + Debug
111 + Copy
112 + 'static,
113{
114 data: Array1<T>,
116 indices: Array1<usize>,
118 indptr: Array1<usize>,
120 shape: (usize, usize),
122 has_sorted_indices: bool,
124}
125
126impl<T> CsrArray<T>
127where
128 T: Float
129 + Add<Output = T>
130 + Sub<Output = T>
131 + Mul<Output = T>
132 + Div<Output = T>
133 + Debug
134 + Copy
135 + 'static,
136{
137 pub fn new(
151 data: Array1<T>,
152 indices: Array1<usize>,
153 indptr: Array1<usize>,
154 shape: (usize, usize),
155 ) -> SparseResult<Self> {
156 if data.len() != indices.len() {
158 return Err(SparseError::InconsistentData {
159 reason: "data and indices must have the same length".to_string(),
160 });
161 }
162
163 if indptr.len() != shape.0 + 1 {
164 return Err(SparseError::InconsistentData {
165 reason: format!(
166 "indptr length ({}) must be one more than the number of rows ({})",
167 indptr.len(),
168 shape.0
169 ),
170 });
171 }
172
173 if let Some(&max_idx) = indices.iter().max() {
174 if max_idx >= shape.1 {
175 return Err(SparseError::IndexOutOfBounds {
176 index: (0, max_idx),
177 shape,
178 });
179 }
180 }
181
182 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
183 if first != 0 {
184 return Err(SparseError::InconsistentData {
185 reason: "first element of indptr must be 0".to_string(),
186 });
187 }
188
189 if last != data.len() {
190 return Err(SparseError::InconsistentData {
191 reason: format!(
192 "last element of indptr ({}) must equal data length ({})",
193 last,
194 data.len()
195 ),
196 });
197 }
198 }
199
200 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
201
202 Ok(Self {
203 data,
204 indices,
205 indptr,
206 shape,
207 has_sorted_indices,
208 })
209 }
210
211 pub fn from_triplets(
282 rows: &[usize],
283 cols: &[usize],
284 data: &[T],
285 shape: (usize, usize),
286 sorted: bool,
287 ) -> SparseResult<Self> {
288 if rows.len() != cols.len() || rows.len() != data.len() {
289 return Err(SparseError::InconsistentData {
290 reason: "rows, cols, and data must have the same length".to_string(),
291 });
292 }
293
294 if rows.is_empty() {
295 let indptr = Array1::zeros(shape.0 + 1);
297 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
298 }
299
300 let nnz = rows.len();
301 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
302
303 for i in 0..nnz {
304 if rows[i] >= shape.0 || cols[i] >= shape.1 {
305 return Err(SparseError::IndexOutOfBounds {
306 index: (rows[i], cols[i]),
307 shape,
308 });
309 }
310 all_data.push((rows[i], cols[i], data[i]));
311 }
312
313 if !sorted {
314 all_data.sort_by_key(|&(row, col_, _)| (row, col_));
315 }
316
317 let mut row_counts = vec![0; shape.0];
319 for &(row_, _, _) in &all_data {
320 row_counts[row_] += 1;
321 }
322
323 let mut indptr = Vec::with_capacity(shape.0 + 1);
325 indptr.push(0);
326 let mut cumsum = 0;
327 for &count in &row_counts {
328 cumsum += count;
329 indptr.push(cumsum);
330 }
331
332 let mut indices = Vec::with_capacity(nnz);
334 let mut values = Vec::with_capacity(nnz);
335
336 for (_, col, val) in all_data {
337 indices.push(col);
338 values.push(val);
339 }
340
341 Self::new(
342 Array1::from_vec(values),
343 Array1::from_vec(indices),
344 Array1::from_vec(indptr),
345 shape,
346 )
347 }
348
349 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
351 for row in 0..indptr.len() - 1 {
352 let start = indptr[row];
353 let end = indptr[row + 1];
354
355 for i in start..end.saturating_sub(1) {
356 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
357 return false;
358 }
359 }
360 }
361 true
362 }
363
364 pub fn get_data(&self) -> &Array1<T> {
366 &self.data
367 }
368
369 pub fn get_indices(&self) -> &Array1<usize> {
371 &self.indices
372 }
373
374 pub fn get_indptr(&self) -> &Array1<usize> {
376 &self.indptr
377 }
378
379 pub fn nrows(&self) -> usize {
381 self.shape.0
382 }
383
384 pub fn ncols(&self) -> usize {
386 self.shape.1
387 }
388
389 pub fn shape(&self) -> (usize, usize) {
391 self.shape
392 }
393}
394
395impl<T> SparseArray<T> for CsrArray<T>
396where
397 T: Float
398 + Add<Output = T>
399 + Sub<Output = T>
400 + Mul<Output = T>
401 + Div<Output = T>
402 + Debug
403 + Copy
404 + 'static,
405{
406 fn shape(&self) -> (usize, usize) {
407 self.shape
408 }
409
410 fn nnz(&self) -> usize {
411 self.data.len()
412 }
413
414 fn dtype(&self) -> &str {
415 "float" }
417
418 fn to_array(&self) -> Array2<T> {
419 let (rows, cols) = self.shape;
420 let mut result = Array2::zeros((rows, cols));
421
422 for row in 0..rows {
423 let start = self.indptr[row];
424 let end = self.indptr[row + 1];
425
426 for i in start..end {
427 let col = self.indices[i];
428 result[[row, col]] = self.data[i];
429 }
430 }
431
432 result
433 }
434
435 fn toarray(&self) -> Array2<T> {
436 self.to_array()
437 }
438
439 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
440 Ok(Box::new(self.clone()))
443 }
444
445 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446 Ok(Box::new(self.clone()))
447 }
448
449 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
450 Ok(Box::new(self.clone()))
453 }
454
455 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456 Ok(Box::new(self.clone()))
459 }
460
461 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462 Ok(Box::new(self.clone()))
465 }
466
467 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468 Ok(Box::new(self.clone()))
471 }
472
473 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
474 Ok(Box::new(self.clone()))
477 }
478
479 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
480 let self_array = self.to_array();
483 let other_array = other.to_array();
484
485 if self.shape() != other.shape() {
486 return Err(SparseError::DimensionMismatch {
487 expected: self.shape().0,
488 found: other.shape().0,
489 });
490 }
491
492 let result = &self_array + &other_array;
493
494 let (rows, cols) = self.shape();
496 let mut data = Vec::new();
497 let mut indices = Vec::new();
498 let mut indptr = vec![0];
499
500 for row in 0..rows {
501 for col in 0..cols {
502 let val = result[[row, col]];
503 if !val.is_zero() {
504 data.push(val);
505 indices.push(col);
506 }
507 }
508 indptr.push(data.len());
509 }
510
511 CsrArray::new(
512 Array1::from_vec(data),
513 Array1::from_vec(indices),
514 Array1::from_vec(indptr),
515 self.shape(),
516 )
517 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
518 }
519
520 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
521 let self_array = self.to_array();
523 let other_array = other.to_array();
524
525 if self.shape() != other.shape() {
526 return Err(SparseError::DimensionMismatch {
527 expected: self.shape().0,
528 found: other.shape().0,
529 });
530 }
531
532 let result = &self_array - &other_array;
533
534 let (rows, cols) = self.shape();
536 let mut data = Vec::new();
537 let mut indices = Vec::new();
538 let mut indptr = vec![0];
539
540 for row in 0..rows {
541 for col in 0..cols {
542 let val = result[[row, col]];
543 if !val.is_zero() {
544 data.push(val);
545 indices.push(col);
546 }
547 }
548 indptr.push(data.len());
549 }
550
551 CsrArray::new(
552 Array1::from_vec(data),
553 Array1::from_vec(indices),
554 Array1::from_vec(indptr),
555 self.shape(),
556 )
557 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
558 }
559
560 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
561 let self_array = self.to_array();
564 let other_array = other.to_array();
565
566 if self.shape() != other.shape() {
567 return Err(SparseError::DimensionMismatch {
568 expected: self.shape().0,
569 found: other.shape().0,
570 });
571 }
572
573 let result = &self_array * &other_array;
574
575 let (rows, cols) = self.shape();
577 let mut data = Vec::new();
578 let mut indices = Vec::new();
579 let mut indptr = vec![0];
580
581 for row in 0..rows {
582 for col in 0..cols {
583 let val = result[[row, col]];
584 if !val.is_zero() {
585 data.push(val);
586 indices.push(col);
587 }
588 }
589 indptr.push(data.len());
590 }
591
592 CsrArray::new(
593 Array1::from_vec(data),
594 Array1::from_vec(indices),
595 Array1::from_vec(indptr),
596 self.shape(),
597 )
598 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
599 }
600
601 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
602 let self_array = self.to_array();
604 let other_array = other.to_array();
605
606 if self.shape() != other.shape() {
607 return Err(SparseError::DimensionMismatch {
608 expected: self.shape().0,
609 found: other.shape().0,
610 });
611 }
612
613 let result = &self_array / &other_array;
614
615 let (rows, cols) = self.shape();
617 let mut data = Vec::new();
618 let mut indices = Vec::new();
619 let mut indptr = vec![0];
620
621 for row in 0..rows {
622 for col in 0..cols {
623 let val = result[[row, col]];
624 if !val.is_zero() {
625 data.push(val);
626 indices.push(col);
627 }
628 }
629 indptr.push(data.len());
630 }
631
632 CsrArray::new(
633 Array1::from_vec(data),
634 Array1::from_vec(indices),
635 Array1::from_vec(indptr),
636 self.shape(),
637 )
638 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
639 }
640
641 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
642 let (m, n) = self.shape();
645 let (p, q) = other.shape();
646
647 if n != p {
648 return Err(SparseError::DimensionMismatch {
649 expected: n,
650 found: p,
651 });
652 }
653
654 let mut result = Array2::zeros((m, q));
655 let other_array = other.to_array();
656
657 for row in 0..m {
658 let start = self.indptr[row];
659 let end = self.indptr[row + 1];
660
661 for j in 0..q {
662 let mut sum = T::zero();
663 for idx in start..end {
664 let col = self.indices[idx];
665 sum = sum + self.data[idx] * other_array[[col, j]];
666 }
667 if !sum.is_zero() {
668 result[[row, j]] = sum;
669 }
670 }
671 }
672
673 let mut data = Vec::new();
675 let mut indices = Vec::new();
676 let mut indptr = vec![0];
677
678 for row in 0..m {
679 for col in 0..q {
680 let val = result[[row, col]];
681 if !val.is_zero() {
682 data.push(val);
683 indices.push(col);
684 }
685 }
686 indptr.push(data.len());
687 }
688
689 CsrArray::new(
690 Array1::from_vec(data),
691 Array1::from_vec(indices),
692 Array1::from_vec(indptr),
693 (m, q),
694 )
695 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
696 }
697
698 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
699 let (m, n) = self.shape();
700 if n != other.len() {
701 return Err(SparseError::DimensionMismatch {
702 expected: n,
703 found: other.len(),
704 });
705 }
706
707 let mut result = Array1::zeros(m);
708
709 for row in 0..m {
710 let start = self.indptr[row];
711 let end = self.indptr[row + 1];
712
713 let mut sum = T::zero();
714 for idx in start..end {
715 let col = self.indices[idx];
716 sum = sum + self.data[idx] * other[col];
717 }
718 result[row] = sum;
719 }
720
721 Ok(result)
722 }
723
724 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
725 let (rows, cols) = self.shape();
729 let mut row_indices = Vec::with_capacity(self.nnz());
730 let mut col_indices = Vec::with_capacity(self.nnz());
731 let mut values = Vec::with_capacity(self.nnz());
732
733 for row in 0..rows {
734 let start = self.indptr[row];
735 let end = self.indptr[row + 1];
736
737 for idx in start..end {
738 let col = self.indices[idx];
739 row_indices.push(col); col_indices.push(row);
741 values.push(self.data[idx]);
742 }
743 }
744
745 CsrArray::from_triplets(
747 &row_indices,
748 &col_indices,
749 &values,
750 (cols, rows), false,
752 )
753 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
754 }
755
756 fn copy(&self) -> Box<dyn SparseArray<T>> {
757 Box::new(self.clone())
758 }
759
760 fn get(&self, i: usize, j: usize) -> T {
761 if i >= self.shape.0 || j >= self.shape.1 {
762 return T::zero();
763 }
764
765 let start = self.indptr[i];
766 let end = self.indptr[i + 1];
767
768 for idx in start..end {
769 if self.indices[idx] == j {
770 return self.data[idx];
771 }
772 if self.has_sorted_indices && self.indices[idx] > j {
774 break;
775 }
776 }
777
778 T::zero()
779 }
780
781 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
782 if i >= self.shape.0 || j >= self.shape.1 {
785 return Err(SparseError::IndexOutOfBounds {
786 index: (i, j),
787 shape: self.shape,
788 });
789 }
790
791 let start = self.indptr[i];
792 let end = self.indptr[i + 1];
793
794 for idx in start..end {
796 if self.indices[idx] == j {
797 self.data[idx] = value;
799 return Ok(());
800 }
801 if self.has_sorted_indices && self.indices[idx] > j {
803 return Err(SparseError::NotImplemented(
806 "Inserting new elements in CSR format".to_string(),
807 ));
808 }
809 }
810
811 Err(SparseError::NotImplemented(
814 "Inserting new elements in CSR format".to_string(),
815 ))
816 }
817
818 fn eliminate_zeros(&mut self) {
819 let mut new_data = Vec::new();
821 let mut new_indices = Vec::new();
822 let mut new_indptr = vec![0];
823
824 let (rows, _) = self.shape();
825
826 for row in 0..rows {
827 let start = self.indptr[row];
828 let end = self.indptr[row + 1];
829
830 for idx in start..end {
831 if !self.data[idx].is_zero() {
832 new_data.push(self.data[idx]);
833 new_indices.push(self.indices[idx]);
834 }
835 }
836 new_indptr.push(new_data.len());
837 }
838
839 self.data = Array1::from_vec(new_data);
841 self.indices = Array1::from_vec(new_indices);
842 self.indptr = Array1::from_vec(new_indptr);
843 }
844
845 fn sort_indices(&mut self) {
846 if self.has_sorted_indices {
847 return;
848 }
849
850 let (rows, _) = self.shape();
851
852 for row in 0..rows {
853 let start = self.indptr[row];
854 let end = self.indptr[row + 1];
855
856 if start == end {
857 continue;
858 }
859
860 let mut row_data = Vec::with_capacity(end - start);
862 for idx in start..end {
863 row_data.push((self.indices[idx], self.data[idx]));
864 }
865
866 row_data.sort_by_key(|&(col_, _)| col_);
868
869 for (i, (col, val)) in row_data.into_iter().enumerate() {
871 self.indices[start + i] = col;
872 self.data[start + i] = val;
873 }
874 }
875
876 self.has_sorted_indices = true;
877 }
878
879 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
880 if self.has_sorted_indices {
881 return Box::new(self.clone());
882 }
883
884 let mut sorted = self.clone();
885 sorted.sort_indices();
886 Box::new(sorted)
887 }
888
889 fn has_sorted_indices(&self) -> bool {
890 self.has_sorted_indices
891 }
892
893 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
894 match axis {
895 None => {
896 let mut sum = T::zero();
898 for &val in self.data.iter() {
899 sum = sum + val;
900 }
901 Ok(SparseSum::Scalar(sum))
902 }
903 Some(0) => {
904 let (_, cols) = self.shape();
906 let mut result = vec![T::zero(); cols];
907
908 for row in 0..self.shape.0 {
909 let start = self.indptr[row];
910 let end = self.indptr[row + 1];
911
912 for idx in start..end {
913 let col = self.indices[idx];
914 result[col] = result[col] + self.data[idx];
915 }
916 }
917
918 let mut data = Vec::new();
920 let mut indices = Vec::new();
921 let mut indptr = vec![0];
922
923 for (col, &val) in result.iter().enumerate() {
924 if !val.is_zero() {
925 data.push(val);
926 indices.push(col);
927 }
928 }
929 indptr.push(data.len());
930
931 let result_array = CsrArray::new(
932 Array1::from_vec(data),
933 Array1::from_vec(indices),
934 Array1::from_vec(indptr),
935 (1, cols),
936 )?;
937
938 Ok(SparseSum::SparseArray(Box::new(result_array)))
939 }
940 Some(1) => {
941 let mut result = Vec::with_capacity(self.shape.0);
943
944 for row in 0..self.shape.0 {
945 let start = self.indptr[row];
946 let end = self.indptr[row + 1];
947
948 let mut row_sum = T::zero();
949 for idx in start..end {
950 row_sum = row_sum + self.data[idx];
951 }
952 result.push(row_sum);
953 }
954
955 let mut data = Vec::new();
957 let mut indices = Vec::new();
958 let mut indptr = vec![0];
959
960 for &val in result.iter() {
961 if !val.is_zero() {
962 data.push(val);
963 indices.push(0);
964 indptr.push(data.len());
965 } else {
966 indptr.push(data.len());
967 }
968 }
969
970 let result_array = CsrArray::new(
971 Array1::from_vec(data),
972 Array1::from_vec(indices),
973 Array1::from_vec(indptr),
974 (self.shape.0, 1),
975 )?;
976
977 Ok(SparseSum::SparseArray(Box::new(result_array)))
978 }
979 _ => Err(SparseError::InvalidAxis),
980 }
981 }
982
983 fn max(&self) -> T {
984 if self.data.is_empty() {
985 return T::neg_infinity();
986 }
987
988 let mut max_val = self.data[0];
989 for &val in self.data.iter().skip(1) {
990 if val > max_val {
991 max_val = val;
992 }
993 }
994
995 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
997 max_val = T::zero();
998 }
999
1000 max_val
1001 }
1002
1003 fn min(&self) -> T {
1004 if self.data.is_empty() {
1005 return T::infinity();
1006 }
1007
1008 let mut min_val = self.data[0];
1009 for &val in self.data.iter().skip(1) {
1010 if val < min_val {
1011 min_val = val;
1012 }
1013 }
1014
1015 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1017 min_val = T::zero();
1018 }
1019
1020 min_val
1021 }
1022
1023 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1024 let nnz = self.nnz();
1025 let mut rows = Vec::with_capacity(nnz);
1026 let mut cols = Vec::with_capacity(nnz);
1027 let mut values = Vec::with_capacity(nnz);
1028
1029 for row in 0..self.shape.0 {
1030 let start = self.indptr[row];
1031 let end = self.indptr[row + 1];
1032
1033 for idx in start..end {
1034 let col = self.indices[idx];
1035 rows.push(row);
1036 cols.push(col);
1037 values.push(self.data[idx]);
1038 }
1039 }
1040
1041 (
1042 Array1::from_vec(rows),
1043 Array1::from_vec(cols),
1044 Array1::from_vec(values),
1045 )
1046 }
1047
1048 fn slice(
1049 &self,
1050 row_range: (usize, usize),
1051 col_range: (usize, usize),
1052 ) -> SparseResult<Box<dyn SparseArray<T>>> {
1053 let (start_row, end_row) = row_range;
1054 let (start_col, end_col) = col_range;
1055
1056 if start_row >= self.shape.0
1057 || end_row > self.shape.0
1058 || start_col >= self.shape.1
1059 || end_col > self.shape.1
1060 {
1061 return Err(SparseError::InvalidSliceRange);
1062 }
1063
1064 if start_row >= end_row || start_col >= end_col {
1065 return Err(SparseError::InvalidSliceRange);
1066 }
1067
1068 let mut data = Vec::new();
1069 let mut indices = Vec::new();
1070 let mut indptr = vec![0];
1071
1072 for row in start_row..end_row {
1073 let start = self.indptr[row];
1074 let end = self.indptr[row + 1];
1075
1076 for idx in start..end {
1077 let col = self.indices[idx];
1078 if col >= start_col && col < end_col {
1079 data.push(self.data[idx]);
1080 indices.push(col - start_col);
1081 }
1082 }
1083 indptr.push(data.len());
1084 }
1085
1086 CsrArray::new(
1087 Array1::from_vec(data),
1088 Array1::from_vec(indices),
1089 Array1::from_vec(indptr),
1090 (end_row - start_row, end_col - start_col),
1091 )
1092 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1093 }
1094
1095 fn as_any(&self) -> &dyn std::any::Any {
1096 self
1097 }
1098
1099 fn get_indptr(&self) -> Option<&Array1<usize>> {
1100 Some(&self.indptr)
1101 }
1102
1103 fn indptr(&self) -> Option<&Array1<usize>> {
1104 Some(&self.indptr)
1105 }
1106}
1107
1108impl<T> fmt::Debug for CsrArray<T>
1109where
1110 T: Float
1111 + Add<Output = T>
1112 + Sub<Output = T>
1113 + Mul<Output = T>
1114 + Div<Output = T>
1115 + Debug
1116 + Copy
1117 + 'static,
1118{
1119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1120 write!(
1121 f,
1122 "CsrArray<{}x{}, nnz={}>",
1123 self.shape.0,
1124 self.shape.1,
1125 self.nnz()
1126 )
1127 }
1128}
1129
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture; densely it is
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn sample_csr() -> CsrArray<f64> {
        let values = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let columns = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let row_ptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(values, columns, row_ptr, (3, 3)).unwrap()
    }

    // (row, col, expected value) probes shared by the accessor tests.
    const PROBES: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 1, 3.0),
        (2, 0, 4.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
    ];

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_from_triplets() {
        let row_idx: [usize; 5] = [0, 0, 1, 2, 2];
        let col_idx: [usize; 5] = [0, 2, 1, 0, 2];
        let vals: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 3), false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();
        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];

        assert_eq!(dense.shape(), &[3, 3]);
        for (r, row) in expected.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_eq!(dense[[r, c]], value);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = sample_csr().dot_vector(&x.view()).unwrap();

        assert_eq!(y.len(), 3);
        // Row-wise: 1*1 + 2*3 = 7, 3*2 = 6, 4*1 + 5*3 = 19.
        for (i, want) in [7.0, 6.0, 19.0].iter().enumerate() {
            assert_relative_eq!(y[i], *want);
        }
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        match csr.sum(None).unwrap() {
            SparseSum::Scalar(total) => {
                assert_relative_eq!(total, 15.0);
            }
            _ => panic!("Expected scalar sum"),
        }

        // Axis 0 collapses the rows, leaving one sum per column.
        match csr.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(col_sums) => {
                let dense = col_sums.to_array();
                assert_eq!(dense.shape(), &[1, 3]);
                for (c, want) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(dense[[0, c]], *want);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        // Axis 1 collapses the columns, leaving one sum per row.
        match csr.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(row_sums) => {
                let dense = row_sums.to_array();
                assert_eq!(dense.shape(), &[3, 1]);
                for (r, want) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(dense[[r, 0]], *want);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}