1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14fn array1_insert<T: Clone + Default>(arr: &Array1<T>, idx: usize, value: T) -> Array1<T> {
18 let mut v = arr.to_vec();
19 v.insert(idx, value);
20 Array1::from_vec(v)
21}
22
23#[derive(Clone)]
115pub struct CsrArray<T>
116where
117 T: SparseElement + Div<Output = T> + 'static,
118{
119 data: Array1<T>,
121 indices: Array1<usize>,
123 indptr: Array1<usize>,
125 shape: (usize, usize),
127 has_sorted_indices: bool,
129}
130
131impl<T> CsrArray<T>
132where
133 T: SparseElement + Div<Output = T> + Zero + 'static,
134{
135 pub fn new(
149 data: Array1<T>,
150 indices: Array1<usize>,
151 indptr: Array1<usize>,
152 shape: (usize, usize),
153 ) -> SparseResult<Self> {
154 if data.len() != indices.len() {
156 return Err(SparseError::InconsistentData {
157 reason: "data and indices must have the same length".to_string(),
158 });
159 }
160
161 if indptr.len() != shape.0 + 1 {
162 return Err(SparseError::InconsistentData {
163 reason: format!(
164 "indptr length ({}) must be one more than the number of rows ({})",
165 indptr.len(),
166 shape.0
167 ),
168 });
169 }
170
171 if let Some(&max_idx) = indices.iter().max() {
172 if max_idx >= shape.1 {
173 return Err(SparseError::IndexOutOfBounds {
174 index: (0, max_idx),
175 shape,
176 });
177 }
178 }
179
180 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
181 if first != 0 {
182 return Err(SparseError::InconsistentData {
183 reason: "first element of indptr must be 0".to_string(),
184 });
185 }
186
187 if last != data.len() {
188 return Err(SparseError::InconsistentData {
189 reason: format!(
190 "last element of indptr ({}) must equal data length ({})",
191 last,
192 data.len()
193 ),
194 });
195 }
196 }
197
198 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
199
200 Ok(Self {
201 data,
202 indices,
203 indptr,
204 shape,
205 has_sorted_indices,
206 })
207 }
208
209 pub fn from_triplets(
283 rows: &[usize],
284 cols: &[usize],
285 data: &[T],
286 shape: (usize, usize),
287 sorted: bool,
288 ) -> SparseResult<Self> {
289 if rows.len() != cols.len() || rows.len() != data.len() {
290 return Err(SparseError::InconsistentData {
291 reason: "rows, cols, and data must have the same length".to_string(),
292 });
293 }
294
295 if rows.is_empty() {
296 let indptr = Array1::zeros(shape.0 + 1);
298 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
299 }
300
301 let nnz = rows.len();
302 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
303
304 for i in 0..nnz {
305 if rows[i] >= shape.0 || cols[i] >= shape.1 {
306 return Err(SparseError::IndexOutOfBounds {
307 index: (rows[i], cols[i]),
308 shape,
309 });
310 }
311 all_data.push((rows[i], cols[i], data[i]));
312 }
313
314 if !sorted {
315 all_data.sort_by_key(|&(row, col_, _)| (row, col_));
316 }
317
318 let mut row_counts = vec![0; shape.0];
320 for &(row_, _, _) in &all_data {
321 row_counts[row_] += 1;
322 }
323
324 let mut indptr = Vec::with_capacity(shape.0 + 1);
326 indptr.push(0);
327 let mut cumsum = 0;
328 for &count in &row_counts {
329 cumsum += count;
330 indptr.push(cumsum);
331 }
332
333 let mut indices = Vec::with_capacity(nnz);
335 let mut values = Vec::with_capacity(nnz);
336
337 for (_, col, val) in all_data {
338 indices.push(col);
339 values.push(val);
340 }
341
342 Self::new(
343 Array1::from_vec(values),
344 Array1::from_vec(indices),
345 Array1::from_vec(indptr),
346 shape,
347 )
348 }
349
350 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
352 for row in 0..indptr.len() - 1 {
353 let start = indptr[row];
354 let end = indptr[row + 1];
355
356 for i in start..end.saturating_sub(1) {
357 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
358 return false;
359 }
360 }
361 }
362 true
363 }
364
365 pub fn get_data(&self) -> &Array1<T> {
367 &self.data
368 }
369
370 pub fn get_indices(&self) -> &Array1<usize> {
372 &self.indices
373 }
374
375 pub fn get_indptr(&self) -> &Array1<usize> {
377 &self.indptr
378 }
379
380 pub fn nrows(&self) -> usize {
382 self.shape.0
383 }
384
385 pub fn ncols(&self) -> usize {
387 self.shape.1
388 }
389
390 pub fn shape(&self) -> (usize, usize) {
392 self.shape
393 }
394}
395
396impl<T> SparseArray<T> for CsrArray<T>
397where
398 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
399{
400 fn shape(&self) -> (usize, usize) {
401 self.shape
402 }
403
404 fn nnz(&self) -> usize {
405 self.data.len()
406 }
407
408 fn dtype(&self) -> &str {
409 "float" }
411
412 fn to_array(&self) -> Array2<T> {
413 let (rows, cols) = self.shape;
414 let mut result = Array2::zeros((rows, cols));
415
416 for row in 0..rows {
417 let start = self.indptr[row];
418 let end = self.indptr[row + 1];
419
420 for i in start..end {
421 let col = self.indices[i];
422 result[[row, col]] = self.data[i];
423 }
424 }
425
426 result
427 }
428
429 fn toarray(&self) -> Array2<T> {
430 self.to_array()
431 }
432
433 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
434 Ok(Box::new(self.clone()))
437 }
438
439 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
440 Ok(Box::new(self.clone()))
441 }
442
443 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
444 Ok(Box::new(self.clone()))
447 }
448
449 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
450 Ok(Box::new(self.clone()))
453 }
454
455 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456 Ok(Box::new(self.clone()))
459 }
460
461 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462 Ok(Box::new(self.clone()))
465 }
466
467 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468 Ok(Box::new(self.clone()))
471 }
472
473 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
474 if self.shape() != other.shape() {
475 return Err(SparseError::DimensionMismatch {
476 expected: self.shape().0,
477 found: other.shape().0,
478 });
479 }
480
481 if let Some(other_csr) = other.as_any().downcast_ref::<CsrArray<T>>() {
484 if self.has_sorted_indices && other_csr.has_sorted_indices {
485 let (nrows, _) = self.shape();
486 let mut data = Vec::new();
487 let mut indices = Vec::new();
488 let mut indptr = vec![0usize];
489
490 for row in 0..nrows {
491 let a_start = self.indptr[row];
492 let a_end = self.indptr[row + 1];
493 let b_start = other_csr.indptr[row];
494 let b_end = other_csr.indptr[row + 1];
495
496 let a_cols = &self.indices.as_slice().unwrap_or(&[])[a_start..a_end];
497 let a_data = &self.data.as_slice().unwrap_or(&[])[a_start..a_end];
498 let b_cols = &other_csr.indices.as_slice().unwrap_or(&[])[b_start..b_end];
499 let b_data = &other_csr.data.as_slice().unwrap_or(&[])[b_start..b_end];
500
501 let mut ai = 0;
502 let mut bi = 0;
503 while ai < a_cols.len() && bi < b_cols.len() {
504 if a_cols[ai] < b_cols[bi] {
505 let val = a_data[ai];
506 if val != T::sparse_zero() {
507 data.push(val);
508 indices.push(a_cols[ai]);
509 }
510 ai += 1;
511 } else if a_cols[ai] > b_cols[bi] {
512 let val = b_data[bi];
513 if val != T::sparse_zero() {
514 data.push(val);
515 indices.push(b_cols[bi]);
516 }
517 bi += 1;
518 } else {
519 let val = a_data[ai] + b_data[bi];
520 if val != T::sparse_zero() {
521 data.push(val);
522 indices.push(a_cols[ai]);
523 }
524 ai += 1;
525 bi += 1;
526 }
527 }
528 while ai < a_cols.len() {
529 let val = a_data[ai];
530 if val != T::sparse_zero() {
531 data.push(val);
532 indices.push(a_cols[ai]);
533 }
534 ai += 1;
535 }
536 while bi < b_cols.len() {
537 let val = b_data[bi];
538 if val != T::sparse_zero() {
539 data.push(val);
540 indices.push(b_cols[bi]);
541 }
542 bi += 1;
543 }
544 indptr.push(data.len());
545 }
546
547 return CsrArray::new(
548 Array1::from_vec(data),
549 Array1::from_vec(indices),
550 Array1::from_vec(indptr),
551 self.shape(),
552 )
553 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>);
554 }
555 }
556
557 let self_array = self.to_array();
559 let other_array = other.to_array();
560 let result = &self_array + &other_array;
561
562 let (rows, cols) = self.shape();
563 let mut data = Vec::new();
564 let mut indices = Vec::new();
565 let mut indptr = vec![0];
566
567 for row in 0..rows {
568 for col in 0..cols {
569 let val = result[[row, col]];
570 if val != T::sparse_zero() {
571 data.push(val);
572 indices.push(col);
573 }
574 }
575 indptr.push(data.len());
576 }
577
578 CsrArray::new(
579 Array1::from_vec(data),
580 Array1::from_vec(indices),
581 Array1::from_vec(indptr),
582 self.shape(),
583 )
584 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
585 }
586
587 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
588 let self_array = self.to_array();
590 let other_array = other.to_array();
591
592 if self.shape() != other.shape() {
593 return Err(SparseError::DimensionMismatch {
594 expected: self.shape().0,
595 found: other.shape().0,
596 });
597 }
598
599 let result = &self_array - &other_array;
600
601 let (rows, cols) = self.shape();
603 let mut data = Vec::new();
604 let mut indices = Vec::new();
605 let mut indptr = vec![0];
606
607 for row in 0..rows {
608 for col in 0..cols {
609 let val = result[[row, col]];
610 if val != T::sparse_zero() {
611 data.push(val);
612 indices.push(col);
613 }
614 }
615 indptr.push(data.len());
616 }
617
618 CsrArray::new(
619 Array1::from_vec(data),
620 Array1::from_vec(indices),
621 Array1::from_vec(indptr),
622 self.shape(),
623 )
624 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625 }
626
627 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
628 let self_array = self.to_array();
631 let other_array = other.to_array();
632
633 if self.shape() != other.shape() {
634 return Err(SparseError::DimensionMismatch {
635 expected: self.shape().0,
636 found: other.shape().0,
637 });
638 }
639
640 let result = &self_array * &other_array;
641
642 let (rows, cols) = self.shape();
644 let mut data = Vec::new();
645 let mut indices = Vec::new();
646 let mut indptr = vec![0];
647
648 for row in 0..rows {
649 for col in 0..cols {
650 let val = result[[row, col]];
651 if val != T::sparse_zero() {
652 data.push(val);
653 indices.push(col);
654 }
655 }
656 indptr.push(data.len());
657 }
658
659 CsrArray::new(
660 Array1::from_vec(data),
661 Array1::from_vec(indices),
662 Array1::from_vec(indptr),
663 self.shape(),
664 )
665 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
666 }
667
668 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
669 let self_array = self.to_array();
671 let other_array = other.to_array();
672
673 if self.shape() != other.shape() {
674 return Err(SparseError::DimensionMismatch {
675 expected: self.shape().0,
676 found: other.shape().0,
677 });
678 }
679
680 let result = &self_array / &other_array;
681
682 let (rows, cols) = self.shape();
684 let mut data = Vec::new();
685 let mut indices = Vec::new();
686 let mut indptr = vec![0];
687
688 for row in 0..rows {
689 for col in 0..cols {
690 let val = result[[row, col]];
691 if val != T::sparse_zero() {
692 data.push(val);
693 indices.push(col);
694 }
695 }
696 indptr.push(data.len());
697 }
698
699 CsrArray::new(
700 Array1::from_vec(data),
701 Array1::from_vec(indices),
702 Array1::from_vec(indptr),
703 self.shape(),
704 )
705 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
706 }
707
708 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
709 let (m, n) = self.shape();
710 let (p, q) = other.shape();
711
712 if n != p {
713 return Err(SparseError::DimensionMismatch {
714 expected: n,
715 found: p,
716 });
717 }
718
719 if let Some(other_csr) = other.as_any().downcast_ref::<CsrArray<T>>() {
722 let mut data = Vec::new();
723 let mut col_indices = Vec::new();
724 let mut indptr = vec![0usize];
725
726 let mut workspace = vec![T::sparse_zero(); q];
727 let mut marker = vec![false; q];
728
729 for i in 0..m {
730 let a_start = self.indptr[i];
731 let a_end = self.indptr[i + 1];
732 let mut touched: Vec<usize> = Vec::new();
733
734 for a_idx in a_start..a_end {
735 let k = self.indices[a_idx];
736 let a_ik = self.data[a_idx];
737 if a_ik == T::sparse_zero() {
738 continue;
739 }
740 let b_start = other_csr.indptr[k];
741 let b_end = other_csr.indptr[k + 1];
742 for b_idx in b_start..b_end {
743 let j = other_csr.indices[b_idx];
744 workspace[j] = workspace[j] + a_ik * other_csr.data[b_idx];
745 if !marker[j] {
746 marker[j] = true;
747 touched.push(j);
748 }
749 }
750 }
751
752 touched.sort_unstable();
753 for &j in &touched {
754 let val = workspace[j];
755 if val != T::sparse_zero() {
756 data.push(val);
757 col_indices.push(j);
758 }
759 workspace[j] = T::sparse_zero();
760 marker[j] = false;
761 }
762 indptr.push(data.len());
763 }
764
765 return CsrArray::new(
766 Array1::from_vec(data),
767 Array1::from_vec(col_indices),
768 Array1::from_vec(indptr),
769 (m, q),
770 )
771 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>);
772 }
773
774 let other_array = other.to_array();
776 let mut data = Vec::new();
777 let mut col_indices = Vec::new();
778 let mut indptr = vec![0];
779
780 for row in 0..m {
781 let start = self.indptr[row];
782 let end = self.indptr[row + 1];
783
784 for j in 0..q {
785 let mut sum = T::sparse_zero();
786 for idx in start..end {
787 let col = self.indices[idx];
788 sum = sum + self.data[idx] * other_array[[col, j]];
789 }
790 if sum != T::sparse_zero() {
791 data.push(sum);
792 col_indices.push(j);
793 }
794 }
795 indptr.push(data.len());
796 }
797
798 CsrArray::new(
799 Array1::from_vec(data),
800 Array1::from_vec(col_indices),
801 Array1::from_vec(indptr),
802 (m, q),
803 )
804 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
805 }
806
807 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
808 let (m, n) = self.shape();
809 if n != other.len() {
810 return Err(SparseError::DimensionMismatch {
811 expected: n,
812 found: other.len(),
813 });
814 }
815
816 let mut result = Array1::zeros(m);
817
818 for row in 0..m {
819 let start = self.indptr[row];
820 let end = self.indptr[row + 1];
821
822 let mut sum = T::sparse_zero();
823 for idx in start..end {
824 let col = self.indices[idx];
825 sum = sum + self.data[idx] * other[col];
826 }
827 result[row] = sum;
828 }
829
830 Ok(result)
831 }
832
833 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
834 let (rows, cols) = self.shape();
838 let mut row_indices = Vec::with_capacity(self.nnz());
839 let mut col_indices = Vec::with_capacity(self.nnz());
840 let mut values = Vec::with_capacity(self.nnz());
841
842 for row in 0..rows {
843 let start = self.indptr[row];
844 let end = self.indptr[row + 1];
845
846 for idx in start..end {
847 let col = self.indices[idx];
848 row_indices.push(col); col_indices.push(row);
850 values.push(self.data[idx]);
851 }
852 }
853
854 CsrArray::from_triplets(
856 &row_indices,
857 &col_indices,
858 &values,
859 (cols, rows), false,
861 )
862 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
863 }
864
865 fn copy(&self) -> Box<dyn SparseArray<T>> {
866 Box::new(self.clone())
867 }
868
869 fn get(&self, i: usize, j: usize) -> T {
870 if i >= self.shape.0 || j >= self.shape.1 {
871 return T::sparse_zero();
872 }
873
874 let start = self.indptr[i];
875 let end = self.indptr[i + 1];
876
877 for idx in start..end {
878 if self.indices[idx] == j {
879 return self.data[idx];
880 }
881 if self.has_sorted_indices && self.indices[idx] > j {
883 break;
884 }
885 }
886
887 T::sparse_zero()
888 }
889
890 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
891 if i >= self.shape.0 || j >= self.shape.1 {
892 return Err(SparseError::IndexOutOfBounds {
893 index: (i, j),
894 shape: self.shape,
895 });
896 }
897
898 let start = self.indptr[i];
899 let end = self.indptr[i + 1];
900
901 for idx in start..end {
903 if self.indices[idx] == j {
904 self.data[idx] = value;
905 return Ok(());
906 }
907 if self.has_sorted_indices && self.indices[idx] > j {
908 self.data = array1_insert(&self.data, idx, value);
910 self.indices = array1_insert(&self.indices, idx, j);
911 for row_ptr in self.indptr.iter_mut().skip(i + 1) {
913 *row_ptr += 1;
914 }
915 return Ok(());
916 }
917 }
918
919 self.data = array1_insert(&self.data, end, value);
921 self.indices = array1_insert(&self.indices, end, j);
922 for row_ptr in self.indptr.iter_mut().skip(i + 1) {
924 *row_ptr += 1;
925 }
926 if self.has_sorted_indices {
930 let new_end = self.indptr[i + 1];
931 let new_start = self.indptr[i];
932 for k in new_start..new_end.saturating_sub(1) {
933 if self.indices[k] > self.indices[k + 1] {
934 self.has_sorted_indices = false;
935 break;
936 }
937 }
938 }
939 Ok(())
940 }
941
942 fn eliminate_zeros(&mut self) {
943 let mut new_data = Vec::new();
945 let mut new_indices = Vec::new();
946 let mut new_indptr = vec![0];
947
948 let (rows, _) = self.shape();
949
950 for row in 0..rows {
951 let start = self.indptr[row];
952 let end = self.indptr[row + 1];
953
954 for idx in start..end {
955 if !SparseElement::is_zero(&self.data[idx]) {
956 new_data.push(self.data[idx]);
957 new_indices.push(self.indices[idx]);
958 }
959 }
960 new_indptr.push(new_data.len());
961 }
962
963 self.data = Array1::from_vec(new_data);
965 self.indices = Array1::from_vec(new_indices);
966 self.indptr = Array1::from_vec(new_indptr);
967 }
968
969 fn sort_indices(&mut self) {
970 if self.has_sorted_indices {
971 return;
972 }
973
974 let (rows, _) = self.shape();
975
976 for row in 0..rows {
977 let start = self.indptr[row];
978 let end = self.indptr[row + 1];
979
980 if start == end {
981 continue;
982 }
983
984 let mut row_data = Vec::with_capacity(end - start);
986 for idx in start..end {
987 row_data.push((self.indices[idx], self.data[idx]));
988 }
989
990 row_data.sort_by_key(|&(col_, _)| col_);
992
993 for (i, (col, val)) in row_data.into_iter().enumerate() {
995 self.indices[start + i] = col;
996 self.data[start + i] = val;
997 }
998 }
999
1000 self.has_sorted_indices = true;
1001 }
1002
1003 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
1004 if self.has_sorted_indices {
1005 return Box::new(self.clone());
1006 }
1007
1008 let mut sorted = self.clone();
1009 sorted.sort_indices();
1010 Box::new(sorted)
1011 }
1012
1013 fn has_sorted_indices(&self) -> bool {
1014 self.has_sorted_indices
1015 }
1016
1017 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
1018 match axis {
1019 None => {
1020 let mut sum = T::sparse_zero();
1022 for &val in self.data.iter() {
1023 sum = sum + val;
1024 }
1025 Ok(SparseSum::Scalar(sum))
1026 }
1027 Some(0) => {
1028 let (_, cols) = self.shape();
1030 let mut result = vec![T::sparse_zero(); cols];
1031
1032 for row in 0..self.shape.0 {
1033 let start = self.indptr[row];
1034 let end = self.indptr[row + 1];
1035
1036 for idx in start..end {
1037 let col = self.indices[idx];
1038 result[col] = result[col] + self.data[idx];
1039 }
1040 }
1041
1042 let mut data = Vec::new();
1044 let mut indices = Vec::new();
1045 let mut indptr = vec![0];
1046
1047 for (col, &val) in result.iter().enumerate() {
1048 if val != T::sparse_zero() {
1049 data.push(val);
1050 indices.push(col);
1051 }
1052 }
1053 indptr.push(data.len());
1054
1055 let result_array = CsrArray::new(
1056 Array1::from_vec(data),
1057 Array1::from_vec(indices),
1058 Array1::from_vec(indptr),
1059 (1, cols),
1060 )?;
1061
1062 Ok(SparseSum::SparseArray(Box::new(result_array)))
1063 }
1064 Some(1) => {
1065 let mut result = Vec::with_capacity(self.shape.0);
1067
1068 for row in 0..self.shape.0 {
1069 let start = self.indptr[row];
1070 let end = self.indptr[row + 1];
1071
1072 let mut row_sum = T::sparse_zero();
1073 for idx in start..end {
1074 row_sum = row_sum + self.data[idx];
1075 }
1076 result.push(row_sum);
1077 }
1078
1079 let mut data = Vec::new();
1081 let mut indices = Vec::new();
1082 let mut indptr = vec![0];
1083
1084 for &val in result.iter() {
1085 if val != T::sparse_zero() {
1086 data.push(val);
1087 indices.push(0);
1088 indptr.push(data.len());
1089 } else {
1090 indptr.push(data.len());
1091 }
1092 }
1093
1094 let result_array = CsrArray::new(
1095 Array1::from_vec(data),
1096 Array1::from_vec(indices),
1097 Array1::from_vec(indptr),
1098 (self.shape.0, 1),
1099 )?;
1100
1101 Ok(SparseSum::SparseArray(Box::new(result_array)))
1102 }
1103 _ => Err(SparseError::InvalidAxis),
1104 }
1105 }
1106
1107 fn max(&self) -> T {
1108 if self.data.is_empty() {
1109 return T::sparse_zero();
1111 }
1112
1113 let mut max_val = self.data[0];
1114 for &val in self.data.iter().skip(1) {
1115 if val > max_val {
1116 max_val = val;
1117 }
1118 }
1119
1120 let zero = T::sparse_zero();
1123 if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
1124 max_val = zero;
1125 }
1126
1127 max_val
1128 }
1129
1130 fn min(&self) -> T {
1131 if self.data.is_empty() {
1132 return T::sparse_zero();
1134 }
1135
1136 let mut min_val = self.data[0];
1137 for &val in self.data.iter().skip(1) {
1138 if val < min_val {
1139 min_val = val;
1140 }
1141 }
1142
1143 let zero = T::sparse_zero();
1146 if min_val > zero && self.nnz() < self.shape.0 * self.shape.1 {
1147 min_val = zero;
1148 }
1149
1150 min_val
1151 }
1152
1153 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1154 let nnz = self.nnz();
1155 let mut rows = Vec::with_capacity(nnz);
1156 let mut cols = Vec::with_capacity(nnz);
1157 let mut values = Vec::with_capacity(nnz);
1158
1159 for row in 0..self.shape.0 {
1160 let start = self.indptr[row];
1161 let end = self.indptr[row + 1];
1162
1163 for idx in start..end {
1164 let col = self.indices[idx];
1165 rows.push(row);
1166 cols.push(col);
1167 values.push(self.data[idx]);
1168 }
1169 }
1170
1171 (
1172 Array1::from_vec(rows),
1173 Array1::from_vec(cols),
1174 Array1::from_vec(values),
1175 )
1176 }
1177
1178 fn slice(
1179 &self,
1180 row_range: (usize, usize),
1181 col_range: (usize, usize),
1182 ) -> SparseResult<Box<dyn SparseArray<T>>> {
1183 let (start_row, end_row) = row_range;
1184 let (start_col, end_col) = col_range;
1185
1186 if start_row >= self.shape.0
1187 || end_row > self.shape.0
1188 || start_col >= self.shape.1
1189 || end_col > self.shape.1
1190 {
1191 return Err(SparseError::InvalidSliceRange);
1192 }
1193
1194 if start_row >= end_row || start_col >= end_col {
1195 return Err(SparseError::InvalidSliceRange);
1196 }
1197
1198 let mut data = Vec::new();
1199 let mut indices = Vec::new();
1200 let mut indptr = vec![0];
1201
1202 for row in start_row..end_row {
1203 let start = self.indptr[row];
1204 let end = self.indptr[row + 1];
1205
1206 for idx in start..end {
1207 let col = self.indices[idx];
1208 if col >= start_col && col < end_col {
1209 data.push(self.data[idx]);
1210 indices.push(col - start_col);
1211 }
1212 }
1213 indptr.push(data.len());
1214 }
1215
1216 CsrArray::new(
1217 Array1::from_vec(data),
1218 Array1::from_vec(indices),
1219 Array1::from_vec(indptr),
1220 (end_row - start_row, end_col - start_col),
1221 )
1222 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1223 }
1224
1225 fn as_any(&self) -> &dyn std::any::Any {
1226 self
1227 }
1228
1229 fn get_indptr(&self) -> Option<&Array1<usize>> {
1230 Some(&self.indptr)
1231 }
1232
1233 fn indptr(&self) -> Option<&Array1<usize>> {
1234 Some(&self.indptr)
1235 }
1236}
1237
1238impl<T> fmt::Debug for CsrArray<T>
1239where
1240 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
1241{
1242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1243 write!(
1244 f,
1245 "CsrArray<{}x{}, nnz={}>",
1246 self.shape.0,
1247 self.shape.1,
1248 self.nnz()
1249 )
1250 }
1251}
1252
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Dense form of the fixture built by `sample_csr`, row by row.
    const SAMPLE_DENSE: [[f64; 3]; 3] = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];

    /// `(row, col, expected)` lookups checked by the construction tests; the last probe is
    /// an unstored position.
    const PROBES: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 1, 3.0),
        (2, 0, 4.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
    ];

    /// Builds the 3x3 fixture described by `SAMPLE_DENSE` directly from CSR components.
    fn sample_csr() -> CsrArray<f64> {
        CsrArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 2, 1, 0, 2]),
            Array1::from_vec(vec![0, 2, 3, 5]),
            (3, 3),
        )
        .expect("Operation failed")
    }

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = [0, 0, 1, 2, 2];
        let cols = [0, 2, 1, 0, 2];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];

        let csr =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        for (r, row) in SAMPLE_DENSE.iter().enumerate() {
            for (c, &expected) in row.iter().enumerate() {
                assert_eq!(dense[[r, c]], expected);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).expect("Operation failed");

        assert_eq!(result.len(), 3);
        // [1*1 + 2*3, 3*2, 4*1 + 5*3]
        for (k, &expected) in [7.0, 6.0, 19.0].iter().enumerate() {
            assert_relative_eq!(result[k], expected);
        }
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        match csr.sum(None).expect("Operation failed") {
            SparseSum::Scalar(total) => {
                assert_relative_eq!(total, 15.0);
            }
            _ => panic!("Expected scalar sum"),
        }

        match csr.sum(Some(0)).expect("Operation failed") {
            SparseSum::SparseArray(per_col) => {
                let dense = per_col.to_array();
                assert_eq!(dense.shape(), &[1, 3]);
                for (c, &expected) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(dense[[0, c]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        match csr.sum(Some(1)).expect("Operation failed") {
            SparseSum::SparseArray(per_row) => {
                let dense = per_row.to_array();
                assert_eq!(dense.shape(), &[3, 1]);
                for (r, &expected) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(dense[[r, 0]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}