1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14#[derive(Clone)]
30pub struct CsrArray<T>
31where
32 T: Float
33 + Add<Output = T>
34 + Sub<Output = T>
35 + Mul<Output = T>
36 + Div<Output = T>
37 + Debug
38 + Copy
39 + 'static,
40{
41 data: Array1<T>,
43 indices: Array1<usize>,
45 indptr: Array1<usize>,
47 shape: (usize, usize),
49 has_sorted_indices: bool,
51}
52
53impl<T> CsrArray<T>
54where
55 T: Float
56 + Add<Output = T>
57 + Sub<Output = T>
58 + Mul<Output = T>
59 + Div<Output = T>
60 + Debug
61 + Copy
62 + 'static,
63{
64 pub fn new(
78 data: Array1<T>,
79 indices: Array1<usize>,
80 indptr: Array1<usize>,
81 shape: (usize, usize),
82 ) -> SparseResult<Self> {
83 if data.len() != indices.len() {
85 return Err(SparseError::InconsistentData {
86 reason: "data and indices must have the same length".to_string(),
87 });
88 }
89
90 if indptr.len() != shape.0 + 1 {
91 return Err(SparseError::InconsistentData {
92 reason: format!(
93 "indptr length ({}) must be one more than the number of rows ({})",
94 indptr.len(),
95 shape.0
96 ),
97 });
98 }
99
100 if let Some(&max_idx) = indices.iter().max() {
101 if max_idx >= shape.1 {
102 return Err(SparseError::IndexOutOfBounds {
103 index: (0, max_idx),
104 shape,
105 });
106 }
107 }
108
109 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
110 if first != 0 {
111 return Err(SparseError::InconsistentData {
112 reason: "first element of indptr must be 0".to_string(),
113 });
114 }
115
116 if last != data.len() {
117 return Err(SparseError::InconsistentData {
118 reason: format!(
119 "last element of indptr ({}) must equal data length ({})",
120 last,
121 data.len()
122 ),
123 });
124 }
125 }
126
127 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
128
129 Ok(Self {
130 data,
131 indices,
132 indptr,
133 shape,
134 has_sorted_indices,
135 })
136 }
137
138 pub fn from_triplets(
153 rows: &[usize],
154 cols: &[usize],
155 data: &[T],
156 shape: (usize, usize),
157 sorted: bool,
158 ) -> SparseResult<Self> {
159 if rows.len() != cols.len() || rows.len() != data.len() {
160 return Err(SparseError::InconsistentData {
161 reason: "rows, cols, and data must have the same length".to_string(),
162 });
163 }
164
165 if rows.is_empty() {
166 let indptr = Array1::zeros(shape.0 + 1);
168 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
169 }
170
171 let nnz = rows.len();
172 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
173
174 for i in 0..nnz {
175 if rows[i] >= shape.0 || cols[i] >= shape.1 {
176 return Err(SparseError::IndexOutOfBounds {
177 index: (rows[i], cols[i]),
178 shape,
179 });
180 }
181 all_data.push((rows[i], cols[i], data[i]));
182 }
183
184 if !sorted {
185 all_data.sort_by_key(|&(row, col, _)| (row, col));
186 }
187
188 let mut row_counts = vec![0; shape.0];
190 for &(row, _, _) in &all_data {
191 row_counts[row] += 1;
192 }
193
194 let mut indptr = Vec::with_capacity(shape.0 + 1);
196 indptr.push(0);
197 let mut cumsum = 0;
198 for &count in &row_counts {
199 cumsum += count;
200 indptr.push(cumsum);
201 }
202
203 let mut indices = Vec::with_capacity(nnz);
205 let mut values = Vec::with_capacity(nnz);
206
207 for (_, col, val) in all_data {
208 indices.push(col);
209 values.push(val);
210 }
211
212 Self::new(
213 Array1::from_vec(values),
214 Array1::from_vec(indices),
215 Array1::from_vec(indptr),
216 shape,
217 )
218 }
219
220 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
222 for row in 0..indptr.len() - 1 {
223 let start = indptr[row];
224 let end = indptr[row + 1];
225
226 for i in start..end.saturating_sub(1) {
227 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
228 return false;
229 }
230 }
231 }
232 true
233 }
234
235 pub fn get_data(&self) -> &Array1<T> {
237 &self.data
238 }
239
240 pub fn get_indices(&self) -> &Array1<usize> {
242 &self.indices
243 }
244
245 pub fn get_indptr(&self) -> &Array1<usize> {
247 &self.indptr
248 }
249
250 pub fn nrows(&self) -> usize {
252 self.shape.0
253 }
254
255 pub fn ncols(&self) -> usize {
257 self.shape.1
258 }
259
260 pub fn shape(&self) -> (usize, usize) {
262 self.shape
263 }
264}
265
266impl<T> SparseArray<T> for CsrArray<T>
267where
268 T: Float
269 + Add<Output = T>
270 + Sub<Output = T>
271 + Mul<Output = T>
272 + Div<Output = T>
273 + Debug
274 + Copy
275 + 'static,
276{
277 fn shape(&self) -> (usize, usize) {
278 self.shape
279 }
280
281 fn nnz(&self) -> usize {
282 self.data.len()
283 }
284
285 fn dtype(&self) -> &str {
286 "float" }
288
289 fn to_array(&self) -> Array2<T> {
290 let (rows, cols) = self.shape;
291 let mut result = Array2::zeros((rows, cols));
292
293 for row in 0..rows {
294 let start = self.indptr[row];
295 let end = self.indptr[row + 1];
296
297 for i in start..end {
298 let col = self.indices[i];
299 result[[row, col]] = self.data[i];
300 }
301 }
302
303 result
304 }
305
306 fn toarray(&self) -> Array2<T> {
307 self.to_array()
308 }
309
310 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
311 Ok(Box::new(self.clone()))
314 }
315
316 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
317 Ok(Box::new(self.clone()))
318 }
319
320 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
321 Ok(Box::new(self.clone()))
324 }
325
326 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
327 Ok(Box::new(self.clone()))
330 }
331
332 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
333 Ok(Box::new(self.clone()))
336 }
337
338 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
339 Ok(Box::new(self.clone()))
342 }
343
344 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
345 Ok(Box::new(self.clone()))
348 }
349
350 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
351 let self_array = self.to_array();
354 let other_array = other.to_array();
355
356 if self.shape() != other.shape() {
357 return Err(SparseError::DimensionMismatch {
358 expected: self.shape().0,
359 found: other.shape().0,
360 });
361 }
362
363 let result = &self_array + &other_array;
364
365 let (rows, cols) = self.shape();
367 let mut data = Vec::new();
368 let mut indices = Vec::new();
369 let mut indptr = vec![0];
370
371 for row in 0..rows {
372 for col in 0..cols {
373 let val = result[[row, col]];
374 if !val.is_zero() {
375 data.push(val);
376 indices.push(col);
377 }
378 }
379 indptr.push(data.len());
380 }
381
382 CsrArray::new(
383 Array1::from_vec(data),
384 Array1::from_vec(indices),
385 Array1::from_vec(indptr),
386 self.shape(),
387 )
388 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
389 }
390
391 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
392 let self_array = self.to_array();
394 let other_array = other.to_array();
395
396 if self.shape() != other.shape() {
397 return Err(SparseError::DimensionMismatch {
398 expected: self.shape().0,
399 found: other.shape().0,
400 });
401 }
402
403 let result = &self_array - &other_array;
404
405 let (rows, cols) = self.shape();
407 let mut data = Vec::new();
408 let mut indices = Vec::new();
409 let mut indptr = vec![0];
410
411 for row in 0..rows {
412 for col in 0..cols {
413 let val = result[[row, col]];
414 if !val.is_zero() {
415 data.push(val);
416 indices.push(col);
417 }
418 }
419 indptr.push(data.len());
420 }
421
422 CsrArray::new(
423 Array1::from_vec(data),
424 Array1::from_vec(indices),
425 Array1::from_vec(indptr),
426 self.shape(),
427 )
428 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
429 }
430
431 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
432 let self_array = self.to_array();
435 let other_array = other.to_array();
436
437 if self.shape() != other.shape() {
438 return Err(SparseError::DimensionMismatch {
439 expected: self.shape().0,
440 found: other.shape().0,
441 });
442 }
443
444 let result = &self_array * &other_array;
445
446 let (rows, cols) = self.shape();
448 let mut data = Vec::new();
449 let mut indices = Vec::new();
450 let mut indptr = vec![0];
451
452 for row in 0..rows {
453 for col in 0..cols {
454 let val = result[[row, col]];
455 if !val.is_zero() {
456 data.push(val);
457 indices.push(col);
458 }
459 }
460 indptr.push(data.len());
461 }
462
463 CsrArray::new(
464 Array1::from_vec(data),
465 Array1::from_vec(indices),
466 Array1::from_vec(indptr),
467 self.shape(),
468 )
469 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
470 }
471
472 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
473 let self_array = self.to_array();
475 let other_array = other.to_array();
476
477 if self.shape() != other.shape() {
478 return Err(SparseError::DimensionMismatch {
479 expected: self.shape().0,
480 found: other.shape().0,
481 });
482 }
483
484 let result = &self_array / &other_array;
485
486 let (rows, cols) = self.shape();
488 let mut data = Vec::new();
489 let mut indices = Vec::new();
490 let mut indptr = vec![0];
491
492 for row in 0..rows {
493 for col in 0..cols {
494 let val = result[[row, col]];
495 if !val.is_zero() {
496 data.push(val);
497 indices.push(col);
498 }
499 }
500 indptr.push(data.len());
501 }
502
503 CsrArray::new(
504 Array1::from_vec(data),
505 Array1::from_vec(indices),
506 Array1::from_vec(indptr),
507 self.shape(),
508 )
509 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
510 }
511
512 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
513 let (m, n) = self.shape();
516 let (p, q) = other.shape();
517
518 if n != p {
519 return Err(SparseError::DimensionMismatch {
520 expected: n,
521 found: p,
522 });
523 }
524
525 let mut result = Array2::zeros((m, q));
526 let other_array = other.to_array();
527
528 for row in 0..m {
529 let start = self.indptr[row];
530 let end = self.indptr[row + 1];
531
532 for j in 0..q {
533 let mut sum = T::zero();
534 for idx in start..end {
535 let col = self.indices[idx];
536 sum = sum + self.data[idx] * other_array[[col, j]];
537 }
538 if !sum.is_zero() {
539 result[[row, j]] = sum;
540 }
541 }
542 }
543
544 let mut data = Vec::new();
546 let mut indices = Vec::new();
547 let mut indptr = vec![0];
548
549 for row in 0..m {
550 for col in 0..q {
551 let val = result[[row, col]];
552 if !val.is_zero() {
553 data.push(val);
554 indices.push(col);
555 }
556 }
557 indptr.push(data.len());
558 }
559
560 CsrArray::new(
561 Array1::from_vec(data),
562 Array1::from_vec(indices),
563 Array1::from_vec(indptr),
564 (m, q),
565 )
566 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
567 }
568
569 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
570 let (m, n) = self.shape();
571 if n != other.len() {
572 return Err(SparseError::DimensionMismatch {
573 expected: n,
574 found: other.len(),
575 });
576 }
577
578 let mut result = Array1::zeros(m);
579
580 for row in 0..m {
581 let start = self.indptr[row];
582 let end = self.indptr[row + 1];
583
584 let mut sum = T::zero();
585 for idx in start..end {
586 let col = self.indices[idx];
587 sum = sum + self.data[idx] * other[col];
588 }
589 result[row] = sum;
590 }
591
592 Ok(result)
593 }
594
595 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
596 let (rows, cols) = self.shape();
600 let mut row_indices = Vec::with_capacity(self.nnz());
601 let mut col_indices = Vec::with_capacity(self.nnz());
602 let mut values = Vec::with_capacity(self.nnz());
603
604 for row in 0..rows {
605 let start = self.indptr[row];
606 let end = self.indptr[row + 1];
607
608 for idx in start..end {
609 let col = self.indices[idx];
610 row_indices.push(col); col_indices.push(row);
612 values.push(self.data[idx]);
613 }
614 }
615
616 CsrArray::from_triplets(
618 &row_indices,
619 &col_indices,
620 &values,
621 (cols, rows), false,
623 )
624 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625 }
626
627 fn copy(&self) -> Box<dyn SparseArray<T>> {
628 Box::new(self.clone())
629 }
630
631 fn get(&self, i: usize, j: usize) -> T {
632 if i >= self.shape.0 || j >= self.shape.1 {
633 return T::zero();
634 }
635
636 let start = self.indptr[i];
637 let end = self.indptr[i + 1];
638
639 for idx in start..end {
640 if self.indices[idx] == j {
641 return self.data[idx];
642 }
643 if self.has_sorted_indices && self.indices[idx] > j {
645 break;
646 }
647 }
648
649 T::zero()
650 }
651
652 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
653 if i >= self.shape.0 || j >= self.shape.1 {
656 return Err(SparseError::IndexOutOfBounds {
657 index: (i, j),
658 shape: self.shape,
659 });
660 }
661
662 let start = self.indptr[i];
663 let end = self.indptr[i + 1];
664
665 for idx in start..end {
667 if self.indices[idx] == j {
668 self.data[idx] = value;
670 return Ok(());
671 }
672 if self.has_sorted_indices && self.indices[idx] > j {
674 return Err(SparseError::NotImplemented(
677 "Inserting new elements in CSR format".to_string(),
678 ));
679 }
680 }
681
682 Err(SparseError::NotImplemented(
685 "Inserting new elements in CSR format".to_string(),
686 ))
687 }
688
689 fn eliminate_zeros(&mut self) {
690 let mut new_data = Vec::new();
692 let mut new_indices = Vec::new();
693 let mut new_indptr = vec![0];
694
695 let (rows, _) = self.shape();
696
697 for row in 0..rows {
698 let start = self.indptr[row];
699 let end = self.indptr[row + 1];
700
701 for idx in start..end {
702 if !self.data[idx].is_zero() {
703 new_data.push(self.data[idx]);
704 new_indices.push(self.indices[idx]);
705 }
706 }
707 new_indptr.push(new_data.len());
708 }
709
710 self.data = Array1::from_vec(new_data);
712 self.indices = Array1::from_vec(new_indices);
713 self.indptr = Array1::from_vec(new_indptr);
714 }
715
716 fn sort_indices(&mut self) {
717 if self.has_sorted_indices {
718 return;
719 }
720
721 let (rows, _) = self.shape();
722
723 for row in 0..rows {
724 let start = self.indptr[row];
725 let end = self.indptr[row + 1];
726
727 if start == end {
728 continue;
729 }
730
731 let mut row_data = Vec::with_capacity(end - start);
733 for idx in start..end {
734 row_data.push((self.indices[idx], self.data[idx]));
735 }
736
737 row_data.sort_by_key(|&(col, _)| col);
739
740 for (i, (col, val)) in row_data.into_iter().enumerate() {
742 self.indices[start + i] = col;
743 self.data[start + i] = val;
744 }
745 }
746
747 self.has_sorted_indices = true;
748 }
749
750 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
751 if self.has_sorted_indices {
752 return Box::new(self.clone());
753 }
754
755 let mut sorted = self.clone();
756 sorted.sort_indices();
757 Box::new(sorted)
758 }
759
760 fn has_sorted_indices(&self) -> bool {
761 self.has_sorted_indices
762 }
763
764 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
765 match axis {
766 None => {
767 let mut sum = T::zero();
769 for &val in self.data.iter() {
770 sum = sum + val;
771 }
772 Ok(SparseSum::Scalar(sum))
773 }
774 Some(0) => {
775 let (_, cols) = self.shape();
777 let mut result = vec![T::zero(); cols];
778
779 for row in 0..self.shape.0 {
780 let start = self.indptr[row];
781 let end = self.indptr[row + 1];
782
783 for idx in start..end {
784 let col = self.indices[idx];
785 result[col] = result[col] + self.data[idx];
786 }
787 }
788
789 let mut data = Vec::new();
791 let mut indices = Vec::new();
792 let mut indptr = vec![0];
793
794 for (col, &val) in result.iter().enumerate() {
795 if !val.is_zero() {
796 data.push(val);
797 indices.push(col);
798 }
799 }
800 indptr.push(data.len());
801
802 let result_array = CsrArray::new(
803 Array1::from_vec(data),
804 Array1::from_vec(indices),
805 Array1::from_vec(indptr),
806 (1, cols),
807 )?;
808
809 Ok(SparseSum::SparseArray(Box::new(result_array)))
810 }
811 Some(1) => {
812 let mut result = Vec::with_capacity(self.shape.0);
814
815 for row in 0..self.shape.0 {
816 let start = self.indptr[row];
817 let end = self.indptr[row + 1];
818
819 let mut row_sum = T::zero();
820 for idx in start..end {
821 row_sum = row_sum + self.data[idx];
822 }
823 result.push(row_sum);
824 }
825
826 let mut data = Vec::new();
828 let mut indices = Vec::new();
829 let mut indptr = vec![0];
830
831 for &val in result.iter() {
832 if !val.is_zero() {
833 data.push(val);
834 indices.push(0);
835 indptr.push(data.len());
836 } else {
837 indptr.push(data.len());
838 }
839 }
840
841 let result_array = CsrArray::new(
842 Array1::from_vec(data),
843 Array1::from_vec(indices),
844 Array1::from_vec(indptr),
845 (self.shape.0, 1),
846 )?;
847
848 Ok(SparseSum::SparseArray(Box::new(result_array)))
849 }
850 _ => Err(SparseError::InvalidAxis),
851 }
852 }
853
854 fn max(&self) -> T {
855 if self.data.is_empty() {
856 return T::neg_infinity();
857 }
858
859 let mut max_val = self.data[0];
860 for &val in self.data.iter().skip(1) {
861 if val > max_val {
862 max_val = val;
863 }
864 }
865
866 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
868 max_val = T::zero();
869 }
870
871 max_val
872 }
873
874 fn min(&self) -> T {
875 if self.data.is_empty() {
876 return T::infinity();
877 }
878
879 let mut min_val = self.data[0];
880 for &val in self.data.iter().skip(1) {
881 if val < min_val {
882 min_val = val;
883 }
884 }
885
886 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
888 min_val = T::zero();
889 }
890
891 min_val
892 }
893
894 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
895 let nnz = self.nnz();
896 let mut rows = Vec::with_capacity(nnz);
897 let mut cols = Vec::with_capacity(nnz);
898 let mut values = Vec::with_capacity(nnz);
899
900 for row in 0..self.shape.0 {
901 let start = self.indptr[row];
902 let end = self.indptr[row + 1];
903
904 for idx in start..end {
905 let col = self.indices[idx];
906 rows.push(row);
907 cols.push(col);
908 values.push(self.data[idx]);
909 }
910 }
911
912 (
913 Array1::from_vec(rows),
914 Array1::from_vec(cols),
915 Array1::from_vec(values),
916 )
917 }
918
919 fn slice(
920 &self,
921 row_range: (usize, usize),
922 col_range: (usize, usize),
923 ) -> SparseResult<Box<dyn SparseArray<T>>> {
924 let (start_row, end_row) = row_range;
925 let (start_col, end_col) = col_range;
926
927 if start_row >= self.shape.0
928 || end_row > self.shape.0
929 || start_col >= self.shape.1
930 || end_col > self.shape.1
931 {
932 return Err(SparseError::InvalidSliceRange);
933 }
934
935 if start_row >= end_row || start_col >= end_col {
936 return Err(SparseError::InvalidSliceRange);
937 }
938
939 let mut data = Vec::new();
940 let mut indices = Vec::new();
941 let mut indptr = vec![0];
942
943 for row in start_row..end_row {
944 let start = self.indptr[row];
945 let end = self.indptr[row + 1];
946
947 for idx in start..end {
948 let col = self.indices[idx];
949 if col >= start_col && col < end_col {
950 data.push(self.data[idx]);
951 indices.push(col - start_col);
952 }
953 }
954 indptr.push(data.len());
955 }
956
957 CsrArray::new(
958 Array1::from_vec(data),
959 Array1::from_vec(indices),
960 Array1::from_vec(indptr),
961 (end_row - start_row, end_col - start_col),
962 )
963 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
964 }
965
966 fn as_any(&self) -> &dyn std::any::Any {
967 self
968 }
969}
970
971impl<T> fmt::Debug for CsrArray<T>
972where
973 T: Float
974 + Add<Output = T>
975 + Sub<Output = T>
976 + Mul<Output = T>
977 + Div<Output = T>
978 + Debug
979 + Copy
980 + 'static,
981{
982 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
983 write!(
984 f,
985 "CsrArray<{}x{}, nnz={}>",
986 self.shape.0,
987 self.shape.1,
988 self.nnz()
989 )
990 }
991}
992
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // The 3x3 matrix used throughout these tests:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn sample_csr() -> CsrArray<f64> {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(data, indices, indptr, (3, 3)).unwrap()
    }

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_new_rejects_inconsistent_input() {
        // data and indices differ in length
        assert!(CsrArray::<f64>::new(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0]),
            Array1::from_vec(vec![0, 2]),
            (1, 3),
        )
        .is_err());

        // indptr must have rows + 1 entries
        assert!(CsrArray::<f64>::new(
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![0]),
            Array1::from_vec(vec![0, 1]),
            (2, 3),
        )
        .is_err());

        // column index outside the matrix
        assert!(CsrArray::<f64>::new(
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![3]),
            Array1::from_vec(vec![0, 1]),
            (1, 3),
        )
        .is_err());
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(
            dense,
            array![[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]]
        );
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).unwrap();

        assert_eq!(result.len(), 3);
        // [1*1 + 2*3, 3*2, 4*1 + 5*3]
        assert_relative_eq!(result[0], 7.0);
        assert_relative_eq!(result[1], 6.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csr_array_dot() {
        let csr = sample_csr();

        let product = csr.dot(&csr).unwrap().to_array();

        assert_eq!(
            product,
            array![[9.0, 0.0, 12.0], [0.0, 9.0, 0.0], [24.0, 0.0, 33.0]]
        );
    }

    #[test]
    fn test_csr_array_transpose() {
        let transposed = sample_csr().transpose().unwrap().to_array();

        assert_eq!(
            transposed,
            array![[1.0, 0.0, 4.0], [0.0, 3.0, 0.0], [2.0, 0.0, 5.0]]
        );
    }

    #[test]
    fn test_csr_array_slice() {
        let sub = sample_csr().slice((0, 2), (1, 3)).unwrap().to_array();

        assert_eq!(sub, array![[0.0, 2.0], [3.0, 0.0]]);
        assert!(sample_csr().slice((2, 2), (0, 3)).is_err());
    }

    #[test]
    fn test_csr_array_eliminate_zeros() {
        let mut csr = CsrArray::new(
            Array1::from_vec(vec![1.0, 0.0, 2.0]),
            Array1::from_vec(vec![0, 1, 2]),
            Array1::from_vec(vec![0, 3]),
            (1, 3),
        )
        .unwrap();

        csr.eliminate_zeros();

        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 1), 0.0);
        assert_eq!(csr.get(0, 2), 2.0);
    }

    #[test]
    fn test_csr_array_max_min() {
        let csr = sample_csr();

        assert_eq!(csr.max(), 5.0);
        // Implicit zeros are smaller than every stored (positive) value.
        assert_eq!(csr.min(), 0.0);
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        if let SparseSum::Scalar(total) = csr.sum(None).unwrap() {
            assert_relative_eq!(total, 15.0);
        } else {
            panic!("Expected scalar sum");
        }

        // Axis 0 collapses the rows, leaving one sum per column.
        if let SparseSum::SparseArray(col_sums) = csr.sum(Some(0)).unwrap() {
            let col_sums = col_sums.to_array();
            assert_eq!(col_sums.shape(), &[1, 3]);
            assert_relative_eq!(col_sums[[0, 0]], 5.0);
            assert_relative_eq!(col_sums[[0, 1]], 3.0);
            assert_relative_eq!(col_sums[[0, 2]], 7.0);
        } else {
            panic!("Expected sparse array sum");
        }

        // Axis 1 collapses the columns, leaving one sum per row.
        if let SparseSum::SparseArray(row_sums) = csr.sum(Some(1)).unwrap() {
            let row_sums = row_sums.to_array();
            assert_eq!(row_sums.shape(), &[3, 1]);
            assert_relative_eq!(row_sums[[0, 0]], 3.0);
            assert_relative_eq!(row_sums[[1, 0]], 3.0);
            assert_relative_eq!(row_sums[[2, 0]], 9.0);
        } else {
            panic!("Expected sparse array sum");
        }
    }
}