1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14#[derive(Clone)]
30pub struct CsrArray<T>
31where
32 T: Float
33 + Add<Output = T>
34 + Sub<Output = T>
35 + Mul<Output = T>
36 + Div<Output = T>
37 + Debug
38 + Copy
39 + 'static,
40{
41 data: Array1<T>,
43 indices: Array1<usize>,
45 indptr: Array1<usize>,
47 shape: (usize, usize),
49 has_sorted_indices: bool,
51}
52
53impl<T> CsrArray<T>
54where
55 T: Float
56 + Add<Output = T>
57 + Sub<Output = T>
58 + Mul<Output = T>
59 + Div<Output = T>
60 + Debug
61 + Copy
62 + 'static,
63{
64 pub fn new(
78 data: Array1<T>,
79 indices: Array1<usize>,
80 indptr: Array1<usize>,
81 shape: (usize, usize),
82 ) -> SparseResult<Self> {
83 if data.len() != indices.len() {
85 return Err(SparseError::InconsistentData {
86 reason: "data and indices must have the same length".to_string(),
87 });
88 }
89
90 if indptr.len() != shape.0 + 1 {
91 return Err(SparseError::InconsistentData {
92 reason: format!(
93 "indptr length ({}) must be one more than the number of rows ({})",
94 indptr.len(),
95 shape.0
96 ),
97 });
98 }
99
100 if let Some(&max_idx) = indices.iter().max() {
101 if max_idx >= shape.1 {
102 return Err(SparseError::IndexOutOfBounds {
103 index: (0, max_idx),
104 shape,
105 });
106 }
107 }
108
109 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
110 if first != 0 {
111 return Err(SparseError::InconsistentData {
112 reason: "first element of indptr must be 0".to_string(),
113 });
114 }
115
116 if last != data.len() {
117 return Err(SparseError::InconsistentData {
118 reason: format!(
119 "last element of indptr ({}) must equal data length ({})",
120 last,
121 data.len()
122 ),
123 });
124 }
125 }
126
127 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
128
129 Ok(Self {
130 data,
131 indices,
132 indptr,
133 shape,
134 has_sorted_indices,
135 })
136 }
137
138 pub fn from_triplets(
153 rows: &[usize],
154 cols: &[usize],
155 data: &[T],
156 shape: (usize, usize),
157 sorted: bool,
158 ) -> SparseResult<Self> {
159 if rows.len() != cols.len() || rows.len() != data.len() {
160 return Err(SparseError::InconsistentData {
161 reason: "rows, cols, and data must have the same length".to_string(),
162 });
163 }
164
165 if rows.is_empty() {
166 let indptr = Array1::zeros(shape.0 + 1);
168 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
169 }
170
171 let nnz = rows.len();
172 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
173
174 for i in 0..nnz {
175 if rows[i] >= shape.0 || cols[i] >= shape.1 {
176 return Err(SparseError::IndexOutOfBounds {
177 index: (rows[i], cols[i]),
178 shape,
179 });
180 }
181 all_data.push((rows[i], cols[i], data[i]));
182 }
183
184 if !sorted {
185 all_data.sort_by_key(|&(row, col, _)| (row, col));
186 }
187
188 let mut row_counts = vec![0; shape.0];
190 for &(row, _, _) in &all_data {
191 row_counts[row] += 1;
192 }
193
194 let mut indptr = Vec::with_capacity(shape.0 + 1);
196 indptr.push(0);
197 let mut cumsum = 0;
198 for &count in &row_counts {
199 cumsum += count;
200 indptr.push(cumsum);
201 }
202
203 let mut indices = Vec::with_capacity(nnz);
205 let mut values = Vec::with_capacity(nnz);
206
207 for (_, col, val) in all_data {
208 indices.push(col);
209 values.push(val);
210 }
211
212 Self::new(
213 Array1::from_vec(values),
214 Array1::from_vec(indices),
215 Array1::from_vec(indptr),
216 shape,
217 )
218 }
219
220 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
222 for row in 0..indptr.len() - 1 {
223 let start = indptr[row];
224 let end = indptr[row + 1];
225
226 for i in start..end.saturating_sub(1) {
227 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
228 return false;
229 }
230 }
231 }
232 true
233 }
234
235 pub fn get_data(&self) -> &Array1<T> {
237 &self.data
238 }
239
240 pub fn get_indices(&self) -> &Array1<usize> {
242 &self.indices
243 }
244
245 pub fn get_indptr(&self) -> &Array1<usize> {
247 &self.indptr
248 }
249}
250
251impl<T> SparseArray<T> for CsrArray<T>
252where
253 T: Float
254 + Add<Output = T>
255 + Sub<Output = T>
256 + Mul<Output = T>
257 + Div<Output = T>
258 + Debug
259 + Copy
260 + 'static,
261{
262 fn shape(&self) -> (usize, usize) {
263 self.shape
264 }
265
266 fn nnz(&self) -> usize {
267 self.data.len()
268 }
269
270 fn dtype(&self) -> &str {
271 "float" }
273
274 fn to_array(&self) -> Array2<T> {
275 let (rows, cols) = self.shape;
276 let mut result = Array2::zeros((rows, cols));
277
278 for row in 0..rows {
279 let start = self.indptr[row];
280 let end = self.indptr[row + 1];
281
282 for i in start..end {
283 let col = self.indices[i];
284 result[[row, col]] = self.data[i];
285 }
286 }
287
288 result
289 }
290
291 fn toarray(&self) -> Array2<T> {
292 self.to_array()
293 }
294
295 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
296 Ok(Box::new(self.clone()))
299 }
300
301 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
302 Ok(Box::new(self.clone()))
303 }
304
305 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
306 Ok(Box::new(self.clone()))
309 }
310
311 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
312 Ok(Box::new(self.clone()))
315 }
316
317 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
318 Ok(Box::new(self.clone()))
321 }
322
323 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
324 Ok(Box::new(self.clone()))
327 }
328
329 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
330 Ok(Box::new(self.clone()))
333 }
334
335 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
336 let self_array = self.to_array();
339 let other_array = other.to_array();
340
341 if self.shape() != other.shape() {
342 return Err(SparseError::DimensionMismatch {
343 expected: self.shape().0,
344 found: other.shape().0,
345 });
346 }
347
348 let result = &self_array + &other_array;
349
350 let (rows, cols) = self.shape();
352 let mut data = Vec::new();
353 let mut indices = Vec::new();
354 let mut indptr = vec![0];
355
356 for row in 0..rows {
357 for col in 0..cols {
358 let val = result[[row, col]];
359 if !val.is_zero() {
360 data.push(val);
361 indices.push(col);
362 }
363 }
364 indptr.push(data.len());
365 }
366
367 CsrArray::new(
368 Array1::from_vec(data),
369 Array1::from_vec(indices),
370 Array1::from_vec(indptr),
371 self.shape(),
372 )
373 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
374 }
375
376 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
377 let self_array = self.to_array();
379 let other_array = other.to_array();
380
381 if self.shape() != other.shape() {
382 return Err(SparseError::DimensionMismatch {
383 expected: self.shape().0,
384 found: other.shape().0,
385 });
386 }
387
388 let result = &self_array - &other_array;
389
390 let (rows, cols) = self.shape();
392 let mut data = Vec::new();
393 let mut indices = Vec::new();
394 let mut indptr = vec![0];
395
396 for row in 0..rows {
397 for col in 0..cols {
398 let val = result[[row, col]];
399 if !val.is_zero() {
400 data.push(val);
401 indices.push(col);
402 }
403 }
404 indptr.push(data.len());
405 }
406
407 CsrArray::new(
408 Array1::from_vec(data),
409 Array1::from_vec(indices),
410 Array1::from_vec(indptr),
411 self.shape(),
412 )
413 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
414 }
415
416 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
417 let self_array = self.to_array();
420 let other_array = other.to_array();
421
422 if self.shape() != other.shape() {
423 return Err(SparseError::DimensionMismatch {
424 expected: self.shape().0,
425 found: other.shape().0,
426 });
427 }
428
429 let result = &self_array * &other_array;
430
431 let (rows, cols) = self.shape();
433 let mut data = Vec::new();
434 let mut indices = Vec::new();
435 let mut indptr = vec![0];
436
437 for row in 0..rows {
438 for col in 0..cols {
439 let val = result[[row, col]];
440 if !val.is_zero() {
441 data.push(val);
442 indices.push(col);
443 }
444 }
445 indptr.push(data.len());
446 }
447
448 CsrArray::new(
449 Array1::from_vec(data),
450 Array1::from_vec(indices),
451 Array1::from_vec(indptr),
452 self.shape(),
453 )
454 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
455 }
456
457 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
458 let self_array = self.to_array();
460 let other_array = other.to_array();
461
462 if self.shape() != other.shape() {
463 return Err(SparseError::DimensionMismatch {
464 expected: self.shape().0,
465 found: other.shape().0,
466 });
467 }
468
469 let result = &self_array / &other_array;
470
471 let (rows, cols) = self.shape();
473 let mut data = Vec::new();
474 let mut indices = Vec::new();
475 let mut indptr = vec![0];
476
477 for row in 0..rows {
478 for col in 0..cols {
479 let val = result[[row, col]];
480 if !val.is_zero() {
481 data.push(val);
482 indices.push(col);
483 }
484 }
485 indptr.push(data.len());
486 }
487
488 CsrArray::new(
489 Array1::from_vec(data),
490 Array1::from_vec(indices),
491 Array1::from_vec(indptr),
492 self.shape(),
493 )
494 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
495 }
496
497 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
498 let (m, n) = self.shape();
501 let (p, q) = other.shape();
502
503 if n != p {
504 return Err(SparseError::DimensionMismatch {
505 expected: n,
506 found: p,
507 });
508 }
509
510 let mut result = Array2::zeros((m, q));
511 let other_array = other.to_array();
512
513 for row in 0..m {
514 let start = self.indptr[row];
515 let end = self.indptr[row + 1];
516
517 for j in 0..q {
518 let mut sum = T::zero();
519 for idx in start..end {
520 let col = self.indices[idx];
521 sum = sum + self.data[idx] * other_array[[col, j]];
522 }
523 if !sum.is_zero() {
524 result[[row, j]] = sum;
525 }
526 }
527 }
528
529 let mut data = Vec::new();
531 let mut indices = Vec::new();
532 let mut indptr = vec![0];
533
534 for row in 0..m {
535 for col in 0..q {
536 let val = result[[row, col]];
537 if !val.is_zero() {
538 data.push(val);
539 indices.push(col);
540 }
541 }
542 indptr.push(data.len());
543 }
544
545 CsrArray::new(
546 Array1::from_vec(data),
547 Array1::from_vec(indices),
548 Array1::from_vec(indptr),
549 (m, q),
550 )
551 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
552 }
553
554 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
555 let (m, n) = self.shape();
556 if n != other.len() {
557 return Err(SparseError::DimensionMismatch {
558 expected: n,
559 found: other.len(),
560 });
561 }
562
563 let mut result = Array1::zeros(m);
564
565 for row in 0..m {
566 let start = self.indptr[row];
567 let end = self.indptr[row + 1];
568
569 let mut sum = T::zero();
570 for idx in start..end {
571 let col = self.indices[idx];
572 sum = sum + self.data[idx] * other[col];
573 }
574 result[row] = sum;
575 }
576
577 Ok(result)
578 }
579
580 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
581 let (rows, cols) = self.shape();
585 let mut row_indices = Vec::with_capacity(self.nnz());
586 let mut col_indices = Vec::with_capacity(self.nnz());
587 let mut values = Vec::with_capacity(self.nnz());
588
589 for row in 0..rows {
590 let start = self.indptr[row];
591 let end = self.indptr[row + 1];
592
593 for idx in start..end {
594 let col = self.indices[idx];
595 row_indices.push(col); col_indices.push(row);
597 values.push(self.data[idx]);
598 }
599 }
600
601 CsrArray::from_triplets(
603 &row_indices,
604 &col_indices,
605 &values,
606 (cols, rows), false,
608 )
609 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
610 }
611
612 fn copy(&self) -> Box<dyn SparseArray<T>> {
613 Box::new(self.clone())
614 }
615
616 fn get(&self, i: usize, j: usize) -> T {
617 if i >= self.shape.0 || j >= self.shape.1 {
618 return T::zero();
619 }
620
621 let start = self.indptr[i];
622 let end = self.indptr[i + 1];
623
624 for idx in start..end {
625 if self.indices[idx] == j {
626 return self.data[idx];
627 }
628 if self.has_sorted_indices && self.indices[idx] > j {
630 break;
631 }
632 }
633
634 T::zero()
635 }
636
637 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
638 if i >= self.shape.0 || j >= self.shape.1 {
641 return Err(SparseError::IndexOutOfBounds {
642 index: (i, j),
643 shape: self.shape,
644 });
645 }
646
647 let start = self.indptr[i];
648 let end = self.indptr[i + 1];
649
650 for idx in start..end {
652 if self.indices[idx] == j {
653 self.data[idx] = value;
655 return Ok(());
656 }
657 if self.has_sorted_indices && self.indices[idx] > j {
659 return Err(SparseError::NotImplemented(
662 "Inserting new elements in CSR format".to_string(),
663 ));
664 }
665 }
666
667 Err(SparseError::NotImplemented(
670 "Inserting new elements in CSR format".to_string(),
671 ))
672 }
673
674 fn eliminate_zeros(&mut self) {
675 let mut new_data = Vec::new();
677 let mut new_indices = Vec::new();
678 let mut new_indptr = vec![0];
679
680 let (rows, _) = self.shape();
681
682 for row in 0..rows {
683 let start = self.indptr[row];
684 let end = self.indptr[row + 1];
685
686 for idx in start..end {
687 if !self.data[idx].is_zero() {
688 new_data.push(self.data[idx]);
689 new_indices.push(self.indices[idx]);
690 }
691 }
692 new_indptr.push(new_data.len());
693 }
694
695 self.data = Array1::from_vec(new_data);
697 self.indices = Array1::from_vec(new_indices);
698 self.indptr = Array1::from_vec(new_indptr);
699 }
700
701 fn sort_indices(&mut self) {
702 if self.has_sorted_indices {
703 return;
704 }
705
706 let (rows, _) = self.shape();
707
708 for row in 0..rows {
709 let start = self.indptr[row];
710 let end = self.indptr[row + 1];
711
712 if start == end {
713 continue;
714 }
715
716 let mut row_data = Vec::with_capacity(end - start);
718 for idx in start..end {
719 row_data.push((self.indices[idx], self.data[idx]));
720 }
721
722 row_data.sort_by_key(|&(col, _)| col);
724
725 for (i, (col, val)) in row_data.into_iter().enumerate() {
727 self.indices[start + i] = col;
728 self.data[start + i] = val;
729 }
730 }
731
732 self.has_sorted_indices = true;
733 }
734
735 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
736 if self.has_sorted_indices {
737 return Box::new(self.clone());
738 }
739
740 let mut sorted = self.clone();
741 sorted.sort_indices();
742 Box::new(sorted)
743 }
744
745 fn has_sorted_indices(&self) -> bool {
746 self.has_sorted_indices
747 }
748
749 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
750 match axis {
751 None => {
752 let mut sum = T::zero();
754 for &val in self.data.iter() {
755 sum = sum + val;
756 }
757 Ok(SparseSum::Scalar(sum))
758 }
759 Some(0) => {
760 let (_, cols) = self.shape();
762 let mut result = vec![T::zero(); cols];
763
764 for row in 0..self.shape.0 {
765 let start = self.indptr[row];
766 let end = self.indptr[row + 1];
767
768 for idx in start..end {
769 let col = self.indices[idx];
770 result[col] = result[col] + self.data[idx];
771 }
772 }
773
774 let mut data = Vec::new();
776 let mut indices = Vec::new();
777 let mut indptr = vec![0];
778
779 for (col, &val) in result.iter().enumerate() {
780 if !val.is_zero() {
781 data.push(val);
782 indices.push(col);
783 }
784 }
785 indptr.push(data.len());
786
787 let result_array = CsrArray::new(
788 Array1::from_vec(data),
789 Array1::from_vec(indices),
790 Array1::from_vec(indptr),
791 (1, cols),
792 )?;
793
794 Ok(SparseSum::SparseArray(Box::new(result_array)))
795 }
796 Some(1) => {
797 let mut result = Vec::with_capacity(self.shape.0);
799
800 for row in 0..self.shape.0 {
801 let start = self.indptr[row];
802 let end = self.indptr[row + 1];
803
804 let mut row_sum = T::zero();
805 for idx in start..end {
806 row_sum = row_sum + self.data[idx];
807 }
808 result.push(row_sum);
809 }
810
811 let mut data = Vec::new();
813 let mut indices = Vec::new();
814 let mut indptr = vec![0];
815
816 for &val in result.iter() {
817 if !val.is_zero() {
818 data.push(val);
819 indices.push(0);
820 indptr.push(data.len());
821 } else {
822 indptr.push(data.len());
823 }
824 }
825
826 let result_array = CsrArray::new(
827 Array1::from_vec(data),
828 Array1::from_vec(indices),
829 Array1::from_vec(indptr),
830 (self.shape.0, 1),
831 )?;
832
833 Ok(SparseSum::SparseArray(Box::new(result_array)))
834 }
835 _ => Err(SparseError::InvalidAxis),
836 }
837 }
838
839 fn max(&self) -> T {
840 if self.data.is_empty() {
841 return T::neg_infinity();
842 }
843
844 let mut max_val = self.data[0];
845 for &val in self.data.iter().skip(1) {
846 if val > max_val {
847 max_val = val;
848 }
849 }
850
851 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
853 max_val = T::zero();
854 }
855
856 max_val
857 }
858
859 fn min(&self) -> T {
860 if self.data.is_empty() {
861 return T::infinity();
862 }
863
864 let mut min_val = self.data[0];
865 for &val in self.data.iter().skip(1) {
866 if val < min_val {
867 min_val = val;
868 }
869 }
870
871 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
873 min_val = T::zero();
874 }
875
876 min_val
877 }
878
879 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
880 let nnz = self.nnz();
881 let mut rows = Vec::with_capacity(nnz);
882 let mut cols = Vec::with_capacity(nnz);
883 let mut values = Vec::with_capacity(nnz);
884
885 for row in 0..self.shape.0 {
886 let start = self.indptr[row];
887 let end = self.indptr[row + 1];
888
889 for idx in start..end {
890 let col = self.indices[idx];
891 rows.push(row);
892 cols.push(col);
893 values.push(self.data[idx]);
894 }
895 }
896
897 (
898 Array1::from_vec(rows),
899 Array1::from_vec(cols),
900 Array1::from_vec(values),
901 )
902 }
903
904 fn slice(
905 &self,
906 row_range: (usize, usize),
907 col_range: (usize, usize),
908 ) -> SparseResult<Box<dyn SparseArray<T>>> {
909 let (start_row, end_row) = row_range;
910 let (start_col, end_col) = col_range;
911
912 if start_row >= self.shape.0
913 || end_row > self.shape.0
914 || start_col >= self.shape.1
915 || end_col > self.shape.1
916 {
917 return Err(SparseError::InvalidSliceRange);
918 }
919
920 if start_row >= end_row || start_col >= end_col {
921 return Err(SparseError::InvalidSliceRange);
922 }
923
924 let mut data = Vec::new();
925 let mut indices = Vec::new();
926 let mut indptr = vec![0];
927
928 for row in start_row..end_row {
929 let start = self.indptr[row];
930 let end = self.indptr[row + 1];
931
932 for idx in start..end {
933 let col = self.indices[idx];
934 if col >= start_col && col < end_col {
935 data.push(self.data[idx]);
936 indices.push(col - start_col);
937 }
938 }
939 indptr.push(data.len());
940 }
941
942 CsrArray::new(
943 Array1::from_vec(data),
944 Array1::from_vec(indices),
945 Array1::from_vec(indptr),
946 (end_row - start_row, end_col - start_col),
947 )
948 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
949 }
950
951 fn as_any(&self) -> &dyn std::any::Any {
952 self
953 }
954}
955
956impl<T> fmt::Debug for CsrArray<T>
957where
958 T: Float
959 + Add<Output = T>
960 + Sub<Output = T>
961 + Mul<Output = T>
962 + Div<Output = T>
963 + Debug
964 + Copy
965 + 'static,
966{
967 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
968 write!(
969 f,
970 "CsrArray<{}x{}, nnz={}>",
971 self.shape.0,
972 self.shape.1,
973 self.nnz()
974 )
975 }
976}
977
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3x3 fixture shared by the tests:
    ///
    /// ```text
    /// [[1, 0, 2],
    ///  [0, 3, 0],
    ///  [4, 0, 5]]
    /// ```
    fn sample_csr() -> CsrArray<f64> {
        let values = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let columns = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let row_ptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(values, columns, row_ptr, (3, 3)).unwrap()
    }

    /// `(row, col, expected)` probes used by the element-access tests.
    const PROBES: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 1, 3.0),
        (2, 0, 4.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
    ];

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_from_triplets() {
        let row_ids = [0, 0, 1, 2, 2];
        let col_ids = [0, 2, 1, 0, 2];
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_ids, &col_ids, &values, (3, 3), false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(i, j, expected) in PROBES.iter() {
            assert_eq!(csr.get(i, j), expected);
        }
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();
        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];

        assert_eq!(dense.shape(), &[3, 3]);
        for (r, row) in expected.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_eq!(dense[[r, c]], value);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = csr.dot_vector(&x.view()).unwrap();

        assert_eq!(y.len(), 3);
        // [1*1 + 2*3, 3*2, 4*1 + 5*3]
        for (i, &expected) in [7.0, 6.0, 19.0].iter().enumerate() {
            assert_relative_eq!(y[i], expected);
        }
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        match csr.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_relative_eq!(total, 15.0),
            _ => panic!("Expected scalar sum"),
        }

        match csr.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(row_sum) => {
                let row_sum_array = row_sum.to_array();
                assert_eq!(row_sum_array.shape(), &[1, 3]);
                for (c, &expected) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(row_sum_array[[0, c]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        match csr.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(col_sum) => {
                let col_sum_array = col_sum.to_array();
                assert_eq!(col_sum_array.shape(), &[3, 1]);
                for (r, &expected) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(col_sum_array[[r, 0]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}
1100}