1use ndarray::ArrayView;
11use num_traits::{Float, Num, Signed, Zero};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use std::cmp;
15use std::default::Default;
16use std::iter::{Enumerate, Zip};
17use std::mem;
18use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, MulAssign};
19use std::slice::Iter;
20
21use crate::{Ix1, Ix2, Shape};
22use ndarray::linalg::Dot;
23use ndarray::{self, Array, ArrayBase, ShapeBuilder};
24
25use crate::indexing::SpIndex;
26
27use crate::errors::StructureError;
28use crate::sparse::binop;
29use crate::sparse::permutation::PermViewI;
30use crate::sparse::prelude::*;
31use crate::sparse::prod;
32use crate::sparse::smmp;
33use crate::sparse::to_dense::assign_to_dense;
34use crate::sparse::utils;
35use crate::sparse::vec;
36
/// Storage order of a compressed sparse matrix.
///
/// The stored elements are grouped either row by row (`CSR`) or column by
/// column (`CSC`); that grouping dimension is called the *outer* dimension,
/// the other one the *inner* dimension.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// CSR / CSC are the established names of these formats.
#[allow(clippy::upper_case_acronyms)]
pub enum CompressedStorage {
    /// Compressed Sparse Row: the outer dimension is the rows.
    CSR,
    /// Compressed Sparse Column: the outer dimension is the columns.
    CSC,
}
47
48impl CompressedStorage {
49 pub fn other_storage(self) -> Self {
51 match self {
52 CSR => CSC,
53 CSC => CSR,
54 }
55 }
56}
57
58pub fn outer_dimension(
59 storage: CompressedStorage,
60 rows: usize,
61 cols: usize,
62) -> usize {
63 match storage {
64 CSR => rows,
65 CSC => cols,
66 }
67}
68
69pub fn inner_dimension(
70 storage: CompressedStorage,
71 rows: usize,
72 cols: usize,
73) -> usize {
74 match storage {
75 CSR => cols,
76 CSC => rows,
77 }
78}
79
80pub use self::CompressedStorage::{CSC, CSR};
81
/// Position of a stored element inside the `indices` / `data` arrays of a
/// compressed matrix, as returned by `nnz_index`.
///
/// It can be used to index the matrix directly (`mat[nnz_index]`), which
/// avoids repeating the search for the element.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct NnzIndex(pub usize);
88
/// Iterator over the stored elements of a compressed matrix in storage
/// order, yielding `(&value, (row, col))`.
///
/// Created by `CsMatBase::iter` and `CsMatViewI::iter_rbr`.
pub struct CsIter<'a, N: 'a, I: 'a, Iptr: 'a = I>
where
    I: SpIndex,
    Iptr: SpIndex,
{
    /// Storage order, used to map `(outer, inner)` back to `(row, col)`.
    storage: CompressedStorage,
    /// Outer index of the slice currently being walked; advanced past
    /// every outer slice whose range has been exhausted.
    cur_outer: I,
    /// Index pointer of the iterated matrix.
    indptr: crate::IndPtrView<'a, Iptr>,
    /// Position in the nnz arrays together with the (inner index, value)
    /// pair stored there.
    inner_iter: Enumerate<Zip<Iter<'a, I>, Iter<'a, N>>>,
}
99
100impl<'a, N, I, Iptr> Iterator for CsIter<'a, N, I, Iptr>
101where
102 I: SpIndex,
103 Iptr: SpIndex,
104 N: 'a,
105{
106 type Item = (&'a N, (I, I));
107 fn next(&mut self) -> Option<<Self as Iterator>::Item> {
108 match self.inner_iter.next() {
109 None => None,
110 Some((nnz_index, (&inner_ind, val))) => {
111 loop {
115 let nnz_end = self
116 .indptr
117 .outer_inds_sz(self.cur_outer.index_unchecked())
118 .end;
119 if nnz_index == nnz_end.index_unchecked() {
120 self.cur_outer += I::one();
121 } else {
122 break;
123 }
124 }
125 let (row, col) = match self.storage {
126 CSR => (self.cur_outer, inner_ind),
127 CSC => (inner_ind, self.cur_outer),
128 };
129 Some((val, (row, col)))
130 }
131 }
132 }
133
134 fn size_hint(&self) -> (usize, Option<usize>) {
135 self.inner_iter.size_hint()
136 }
137}
138
139impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
140 CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
141where
142 IptrStorage: Deref<Target = [Iptr]>,
143 IStorage: Deref<Target = [I]>,
144 DStorage: Deref<Target = [N]>,
145{
146 pub(crate) fn new_checked(
147 storage: CompressedStorage,
148 shape: (usize, usize),
149 indptr: IptrStorage,
150 indices: IStorage,
151 data: DStorage,
152 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
153 let (nrows, ncols) = shape;
154 let (inner, outer) = match storage {
155 CSR => (ncols, nrows),
156 CSC => (nrows, ncols),
157 };
158 if data.len() != indices.len() {
159 return Err((
160 indptr,
161 indices,
162 data,
163 StructureError::SizeMismatch(
164 "data and indices have different sizes",
165 ),
166 ));
167 }
168 match crate::sparse::utils::check_compressed_structure(
169 inner,
170 outer,
171 indptr.as_ref(),
172 indices.as_ref(),
173 ) {
174 Err(e) => Err((indptr, indices, data, e)),
175 Ok(_) => Ok(Self {
176 storage,
177 nrows,
178 ncols,
179 indptr: crate::IndPtrBase::new_trusted(indptr),
180 indices,
181 data,
182 }),
183 }
184 }
185
186 pub fn new(
208 shape: (usize, usize),
209 indptr: IptrStorage,
210 indices: IStorage,
211 data: DStorage,
212 ) -> Self {
213 Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
214 .map_err(|(_, _, _, e)| e)
215 .unwrap()
216 }
217
218 pub fn new_csc(
222 shape: (usize, usize),
223 indptr: IptrStorage,
224 indices: IStorage,
225 data: DStorage,
226 ) -> Self {
227 Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
228 .map_err(|(_, _, _, e)| e)
229 .unwrap()
230 }
231
232 pub fn try_new(
236 shape: (usize, usize),
237 indptr: IptrStorage,
238 indices: IStorage,
239 data: DStorage,
240 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
241 Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
242 }
243
244 pub fn try_new_csc(
248 shape: (usize, usize),
249 indptr: IptrStorage,
250 indices: IStorage,
251 data: DStorage,
252 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
253 Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
254 }
255
256 pub unsafe fn new_unchecked(
266 storage: CompressedStorage,
267 shape: Shape,
268 indptr: IptrStorage,
269 indices: IStorage,
270 data: DStorage,
271 ) -> Self {
272 let (nrows, ncols) = shape;
273 Self {
274 storage,
275 nrows,
276 ncols,
277 indptr: crate::IndPtrBase::new_trusted(indptr),
278 indices,
279 data,
280 }
281 }
282
283 pub(crate) fn new_trusted(
286 storage: CompressedStorage,
287 shape: Shape,
288 indptr: IptrStorage,
289 indices: IStorage,
290 data: DStorage,
291 ) -> Self {
292 let (nrows, ncols) = shape;
293 Self {
294 storage,
295 nrows,
296 ncols,
297 indptr: crate::IndPtrBase::new_trusted(indptr),
298 indices,
299 data,
300 }
301 }
302}
303
304impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
305 CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
306where
307 IptrStorage: Deref<Target = [Iptr]>,
308 IStorage: DerefMut<Target = [I]>,
309 DStorage: DerefMut<Target = [N]>,
310{
311 fn new_from_unsorted_checked(
312 storage: CompressedStorage,
313 shape: (usize, usize),
314 indptr: IptrStorage,
315 mut indices: IStorage,
316 mut data: DStorage,
317 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
318 where
319 N: Clone,
320 {
321 let (nrows, ncols) = shape;
322 let (inner, outer) = match storage {
323 CSR => (ncols, nrows),
324 CSC => (nrows, ncols),
325 };
326 if data.len() != indices.len() {
327 return Err((
328 indptr,
329 indices,
330 data,
331 StructureError::SizeMismatch(
332 "data and indices have different sizes",
333 ),
334 ));
335 }
336 let mut buf = Vec::new();
337 for start_stop in indptr.windows(2) {
338 let start = start_stop[0].to_usize().unwrap();
339 let stop = start_stop[1].to_usize().unwrap();
340 let indices = &mut indices[start..stop];
341 if utils::sorted_indices(indices) {
342 continue;
343 }
344 let data = &mut data[start..stop];
345 let len = stop - start;
346 let indices = &mut indices[..len];
347 let data = &mut data[..len];
348 utils::sort_indices_data_slices(indices, data, &mut buf);
349 }
350
351 match crate::sparse::utils::check_compressed_structure(
352 inner,
353 outer,
354 indptr.as_ref(),
355 indices.as_ref(),
356 ) {
357 Err(e) => Err((indptr, indices, data, e)),
358 Ok(_) => Ok(Self {
359 storage,
360 nrows,
361 ncols,
362 indptr: crate::IndPtrBase::new_trusted(indptr),
363 indices,
364 data,
365 }),
366 }
367 }
368
369 pub fn new_from_unsorted(
375 shape: Shape,
376 indptr: IptrStorage,
377 indices: IStorage,
378 data: DStorage,
379 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
380 where
381 N: Clone,
382 {
383 Self::new_from_unsorted_checked(CSR, shape, indptr, indices, data)
384 }
385
386 pub fn new_from_unsorted_csc(
392 shape: Shape,
393 indptr: IptrStorage,
394 indices: IStorage,
395 data: DStorage,
396 ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
397 where
398 N: Clone,
399 {
400 Self::new_from_unsorted_checked(CSC, shape, indptr, indices, data)
401 }
402}
403
404impl<N, I: SpIndex, Iptr: SpIndex> CsMatI<N, I, Iptr> {
406 pub fn eye(dim: usize) -> Self
417 where
418 N: Num + Clone,
419 {
420 let _ = (I::from_usize(dim), Iptr::from_usize(dim)); let n = dim;
422 let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
423 let indices = (0..n).map(I::from_usize_unchecked).collect();
424 let data = vec![N::one(); n];
425 Self::new_trusted(CSR, (n, n), indptr, indices, data)
426 }
427
428 pub fn eye_csc(dim: usize) -> Self
439 where
440 N: Num + Clone,
441 {
442 let _ = (I::from_usize(dim), Iptr::from_usize(dim)); let n = dim;
444 let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
445 let indices = (0..n).map(I::from_usize_unchecked).collect();
446 let data = vec![N::one(); n];
447 Self::new_trusted(CSC, (n, n), indptr, indices, data)
448 }
449 pub fn empty(storage: CompressedStorage, inner_size: usize) -> Self {
451 let shape = match storage {
452 CSR => (0, inner_size),
453 CSC => (inner_size, 0),
454 };
455 Self::new_trusted(
456 storage,
457 shape,
458 vec![Iptr::zero(); 1],
459 Vec::new(),
460 Vec::new(),
461 )
462 }
463
464 pub fn zero(shape: Shape) -> Self {
467 let (nrows, _ncols) = shape;
468 Self::new_trusted(
469 CSR,
470 shape,
471 vec![Iptr::zero(); nrows + 1],
472 Vec::new(),
473 Vec::new(),
474 )
475 }
476
477 pub fn reserve_outer_dim(&mut self, outer_dim_additional: usize) {
479 self.indptr.reserve(outer_dim_additional);
480 }
481
482 pub fn reserve_nnz(&mut self, nnz_additional: usize) {
484 self.indices.reserve(nnz_additional);
485 self.data.reserve(nnz_additional);
486 }
487
488 pub fn reserve_outer_dim_exact(&mut self, outer_dim_lim: usize) {
490 self.indptr.reserve_exact(outer_dim_lim + 1);
491 }
492
493 pub fn reserve_nnz_exact(&mut self, nnz_lim: usize) {
495 self.indices.reserve_exact(nnz_lim);
496 self.data.reserve_exact(nnz_lim);
497 }
498
499 pub fn csr_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
503 where
504 N: Num + Clone + cmp::PartialOrd + Signed,
505 {
506 let epsilon = if epsilon > N::zero() {
507 epsilon
508 } else {
509 N::zero()
510 };
511 let nrows = m.shape()[0];
512 let ncols = m.shape()[1];
513
514 let mut indptr = vec![Iptr::zero(); nrows + 1];
515 let mut nnz = 0;
516 for (row, row_count) in m.outer_iter().zip(&mut indptr[1..]) {
517 nnz += row.iter().filter(|&x| x.abs() > epsilon).count();
518 *row_count = Iptr::from_usize(nnz);
519 }
520
521 let mut indices = Vec::with_capacity(nnz);
522 let mut data = Vec::with_capacity(nnz);
523 for row in m.outer_iter() {
524 for (col_ind, x) in row.iter().enumerate() {
525 if x.abs() > epsilon {
526 indices.push(I::from_usize(col_ind));
527 data.push(x.clone());
528 }
529 }
530 }
531 Self {
532 storage: CompressedStorage::CSR,
533 nrows,
534 ncols,
535 indptr: crate::IndPtr::new_trusted(indptr),
536 indices,
537 data,
538 }
539 }
540
541 pub fn csc_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
545 where
546 N: Num + Clone + cmp::PartialOrd + Signed,
547 {
548 Self::csr_from_dense(m.reversed_axes(), epsilon).transpose_into()
549 }
550
551 pub fn append_outer(self, data: &[N]) -> Self
553 where
554 N: Clone + Zero,
555 {
556 unsafe {
558 self.append_outer_iter_unchecked(
559 data.iter()
560 .cloned()
561 .enumerate()
562 .filter(|(_, val)| !val.is_zero()),
563 )
564 }
565 }
566
567 pub fn append_outer_iter<Iter>(self, iter: Iter) -> Self
574 where
575 N: Zero,
576 Iter: IntoIterator<Item = (usize, N)>,
577 {
578 let iter = iter.into_iter();
579 unsafe {
580 self.append_outer_iter_unchecked(AssertOrderedIterator {
581 prev: None,
582 iter: iter.filter(|(_, val)| !val.is_zero()),
583 })
584 }
585 }
586
587 pub unsafe fn append_outer_iter_unchecked<Iter>(
596 mut self,
597 iter: Iter,
598 ) -> Self
599 where
600 Iter: IntoIterator<Item = (usize, N)>,
601 {
602 let iter = iter.into_iter();
603 if let (_, Some(nnz)) = iter.size_hint() {
604 self.reserve_nnz(nnz)
605 }
606 let mut nnz = self.nnz();
607 for (inner_ind, val) in iter {
608 self.indices.push(I::from_usize(inner_ind));
609 self.data.push(val);
610 nnz += 1;
611 }
612 if let Some(last_inner_ind) = self.indices.last() {
613 assert!(
614 last_inner_ind.index_unchecked() < self.inner_dims(),
615 "inner index out of range"
616 );
617 }
618 match self.storage {
619 CSR => self.nrows += 1,
620 CSC => self.ncols += 1,
621 }
622 self.indptr.push(Iptr::from_usize(nnz));
623 self
624 }
625
626 pub fn append_outer_csvec(self, vec: CsVecViewI<N, I>) -> Self
628 where
629 N: Clone,
630 {
631 assert_eq!(self.inner_dims(), vec.dim());
632 unsafe {
634 self.append_outer_iter_unchecked(
635 vec.iter().map(|(i, val)| (i, val.clone())),
636 )
637 }
638 }
639
640 pub fn insert(&mut self, row: usize, col: usize, val: N) {
650 match self.storage() {
651 CSR => self.insert_outer_inner(row, col, val),
652 CSC => self.insert_outer_inner(col, row, val),
653 }
654 }
655
656 fn insert_outer_inner(
657 &mut self,
658 outer_ind: usize,
659 inner_ind: usize,
660 val: N,
661 ) {
662 let outer_dims = self.outer_dims();
663 let inner_ind_idx = I::from_usize(inner_ind);
664 if outer_ind >= outer_dims {
665 let last_nnz = self.indptr.nnz_i();
667 self.indptr.resize(outer_ind + 1, last_nnz);
668 self.set_outer_dims(outer_ind + 1);
669 self.indptr.push(last_nnz + Iptr::one());
670 self.indices.push(inner_ind_idx);
671 self.data.push(val);
672 } else {
673 let range = self.indptr.outer_inds_sz(outer_ind);
675 let location =
676 self.indices[range.clone()].binary_search(&inner_ind_idx);
677 match location {
678 Ok(ind) => {
679 let ind = range.start + ind.index_unchecked();
680 self.data[ind] = val;
681 return;
682 }
683 Err(ind) => {
684 let ind = range.start + ind.index_unchecked();
685 self.indices.insert(ind, inner_ind_idx);
686 self.data.insert(ind, val);
687 self.indptr.record_new_element(outer_ind);
688 }
689 }
690 }
691
692 if inner_ind >= self.inner_dims() {
693 self.set_inner_dims(inner_ind + 1);
694 }
695 }
696
697 fn set_outer_dims(&mut self, outer_dims: usize) {
698 match self.storage() {
699 CSR => self.nrows = outer_dims,
700 CSC => self.ncols = outer_dims,
701 }
702 }
703
704 fn set_inner_dims(&mut self, inner_dims: usize) {
705 match self.storage() {
706 CSR => self.ncols = inner_dims,
707 CSC => self.nrows = inner_dims,
708 }
709 }
710}
711
/// Iterator adaptor that forwards `(index, value)` pairs unchanged and
/// panics as soon as an index is not strictly greater than its predecessor.
pub(crate) struct AssertOrderedIterator<Iter> {
    /// Index of the previously yielded pair, if any.
    prev: Option<usize>,
    /// The wrapped iterator.
    iter: Iter,
}

impl<N, Iter: Iterator<Item = (usize, N)>> Iterator
    for AssertOrderedIterator<Iter>
{
    type Item = (usize, N);

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next()?;
        let idx = item.0;
        if let Some(prev_idx) = self.prev.filter(|&p| p >= idx) {
            panic!("index out of order. {} followed {}", idx, prev_idx);
        }
        self.prev = Some(idx);
        Some(item)
    }

    /// Same bounds as the wrapped iterator: no element is added or removed.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
741
742impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
747 CsMatViewI<'a, N, I, Iptr>
748{
749 #[deprecated(
758 since = "0.10.0",
759 note = "Please use the `slice_outer` method instead"
760 )]
761 pub fn middle_outer_views(
762 &self,
763 i: usize,
764 count: usize,
765 ) -> CsMatViewI<'a, N, I, Iptr> {
766 let iend = i.checked_add(count).unwrap();
767 let (nrows, ncols) = match self.storage {
768 CSR => (count, self.cols()),
769 CSC => (self.rows(), count),
770 };
771 let data_range = self.indptr.outer_inds_slice(i, iend);
772 CsMatViewI {
773 storage: self.storage,
774 nrows,
775 ncols,
776 indptr: self.indptr.middle_slice_rbr(i..iend),
777 indices: &self.indices[data_range.clone()],
778 data: &self.data[data_range],
779 }
780 }
781
782 pub fn iter_rbr(&self) -> CsIter<'a, N, I, Iptr> {
788 CsIter {
789 storage: self.storage,
790 cur_outer: I::zero(),
791 indptr: self.indptr.reborrow(),
792 inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
793 }
794 }
795}
796
impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: Deref<Target = [Iptr]>,
    IndStorage: Deref<Target = [I]>,
    DataStorage: Deref<Target = [N]>,
{
    /// The storage order (CSR or CSC) of this matrix.
    pub fn storage(&self) -> CompressedStorage {
        self.storage
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.ncols
    }

    /// Shape of the matrix as `(rows, cols)`.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }

    /// Number of stored elements, as reported by the index pointer.
    pub fn nnz(&self) -> usize {
        self.indptr.nnz()
    }

    /// Ratio `nnz / (rows * cols)`, computed in `f64`.
    ///
    /// Note: the denominator is zero when `rows` or `cols` is zero, giving
    /// NaN for a matrix with no stored element.
    pub fn density(&self) -> f64 {
        let rows = self.nrows as f64;
        let cols = self.ncols as f64;
        let nnz = self.nnz() as f64;
        nnz / (rows * cols)
    }

    /// Size of the outer dimension: rows for CSR, columns for CSC.
    pub fn outer_dims(&self) -> usize {
        outer_dimension(self.storage, self.nrows, self.ncols)
    }

    /// Size of the inner dimension: columns for CSR, rows for CSC.
    pub fn inner_dims(&self) -> usize {
        match self.storage {
            CSC => self.nrows,
            CSR => self.ncols,
        }
    }

    /// Reference to the element stored at row `i`, column `j`.
    ///
    /// Returns `None` when the outer index is out of range or when nothing
    /// is stored at that location.
    pub fn get(&self, i: usize, j: usize) -> Option<&N> {
        match self.storage {
            CSR => self.get_outer_inner(i, j),
            CSC => self.get_outer_inner(j, i),
        }
    }

    /// View of the index pointer, wrapping the raw indptr storage as-is.
    ///
    /// See `proper_indptr` for the normalized slice produced by
    /// `IndPtr::to_proper`.
    pub fn indptr(&self) -> crate::IndPtrView<'_, Iptr> {
        crate::IndPtrView::new_trusted(self.indptr.raw_storage())
    }

    /// The index pointer as produced by `IndPtr::to_proper`: borrowed when
    /// no conversion is needed, owned otherwise.
    /// NOTE(review): presumably normalized to start at zero — confirm in
    /// `IndPtr::to_proper`.
    pub fn proper_indptr(&self) -> std::borrow::Cow<'_, [Iptr]> {
        self.indptr.to_proper()
    }

    /// Inner indices of the stored elements, in storage order.
    pub fn indices(&self) -> &[I] {
        &self.indices[..]
    }

    /// Values of the stored elements, in storage order.
    pub fn data(&self) -> &[N] {
        &self.data[..]
    }

    /// Destructures the matrix into its raw `(indptr, indices, data)`
    /// storages; the shape and storage order are dropped.
    pub fn into_raw_storage(self) -> (IptrStorage, IndStorage, DataStorage) {
        let Self {
            indptr,
            indices,
            data,
            ..
        } = self;
        (indptr.into_raw_storage(), indices, data)
    }

    /// Whether the matrix is stored as CSC.
    pub fn is_csc(&self) -> bool {
        self.storage == CSC
    }

    /// Whether the matrix is stored as CSR.
    pub fn is_csr(&self) -> bool {
        self.storage == CSR
    }

    /// Transposes in place by swapping the shape and flipping the storage
    /// order; the raw arrays are not touched.
    pub fn transpose_mut(&mut self) {
        mem::swap(&mut self.nrows, &mut self.ncols);
        self.storage = self.storage.other_storage();
    }

    /// Transposes, consuming `self` (same O(1) operation as
    /// `transpose_mut`).
    pub fn transpose_into(mut self) -> Self {
        self.transpose_mut();
        self
    }

    /// Transposed view sharing this matrix's arrays.
    pub fn transpose_view(&self) -> CsMatViewI<'_, N, I, Iptr> {
        CsMatViewI {
            storage: self.storage.other_storage(),
            nrows: self.ncols,
            ncols: self.nrows,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: &self.data[..],
        }
    }

    /// Owned copy of this matrix with the same storage order.
    pub fn to_owned(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone,
    {
        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: self.indptr.to_owned(),
            indices: self.indices.to_vec(),
            data: self.data.to_vec(),
        }
    }

    /// One-hot encodes each outer slice: only its maximum stored element
    /// (NaN values ignored) is kept, with value one.
    ///
    /// Outer slices with no stored non-NaN element stay empty. The result
    /// has the same shape and storage order as `self`.
    pub fn to_inner_onehot(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone + Float + PartialOrd,
    {
        let mut indptr_counter = 0_usize;
        let mut indptr: Vec<Iptr> = Vec::with_capacity(self.indptr.len());

        // At most one element is kept per outer slice.
        let max_data_len = self.indptr.len().min(self.data.len());
        let mut indices: Vec<I> = Vec::with_capacity(max_data_len);
        let mut data = Vec::with_capacity(max_data_len);

        for inner_vec in self.outer_iterator() {
            let hot_element = inner_vec
                .iter()
                // NaNs are filtered out first, so the `expect` below is not
                // expected to fire.
                .filter(|e| !e.1.is_nan())
                .max_by(|a, b| {
                    a.1.partial_cmp(b.1)
                        .expect("Unexpected NaN value was found")
                })
                .map(|a| a.0);

            indptr.push(Iptr::from_usize(indptr_counter));

            if let Some(inner_id) = hot_element {
                indices.push(I::from_usize(inner_id));
                data.push(N::one());
                indptr_counter += 1;
            }
        }

        indptr.push(Iptr::from_usize(indptr_counter));
        CsMatI {
            storage: self.storage,
            nrows: self.rows(),
            ncols: self.cols(),
            indptr: crate::IndPtr::new_trusted(indptr),
            indices,
            data,
        }
    }

    /// Owned copy with different value and index types.
    ///
    /// Values are converted with `Into`. Indices go through `from_usize`,
    /// which is expected to reject values that do not fit the new index
    /// types. NOTE(review): confirm `from_usize` panics on overflow.
    pub fn to_other_types<I2, N2, Iptr2>(&self) -> CsMatI<N2, I2, Iptr2>
    where
        N: Clone + Into<N2>,
        I2: SpIndex,
        Iptr2: SpIndex,
    {
        let indptr = crate::IndPtr::new_trusted(
            self.indptr
                .raw_storage()
                .iter()
                .map(|i| Iptr2::from_usize(i.index_unchecked()))
                .collect(),
        );
        let indices = self
            .indices
            .iter()
            .map(|i| I2::from_usize(i.index_unchecked()))
            .collect();
        let data = self.data.iter().map(|x| x.clone().into()).collect();
        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr,
            indices,
            data,
        }
    }

    /// Borrowed view of the whole matrix.
    pub fn view(&self) -> CsMatViewI<'_, N, I, Iptr> {
        CsMatViewI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: &self.data[..],
        }
    }

    /// View of the sparsity structure only: same indptr and indices, with
    /// the values replaced by a slice of `()` of the same length.
    pub fn structure_view(&self) -> CsStructureViewI<'_, I, Iptr> {
        // SAFETY: `()` is zero-sized with alignment 1, so a non-null pointer
        // (here the start of `data`) is valid for any number of `()`
        // elements; the length matches `data` so the structure stays
        // consistent.
        let zst_data = unsafe {
            std::slice::from_raw_parts(
                self.data.as_ptr().cast::<()>(),
                self.data.len(),
            )
        };
        CsStructureViewI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: zst_data,
        }
    }

    /// Dense copy of the matrix, with zeros where nothing is stored.
    pub fn to_dense(&self) -> Array<N, Ix2>
    where
        N: Clone + Zero,
    {
        let mut res = Array::zeros((self.rows(), self.cols()));
        assign_to_dense(res.view_mut(), self.view());
        res
    }

    /// Iterator over the outer slices (rows for CSR, columns for CSC) as
    /// sparse vectors of dimension `inner_dims()`.
    pub fn outer_iterator(
        &self,
    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewI<'_, N, I>>
           + std::iter::ExactSizeIterator<Item = CsVecViewI<'_, N, I>>
           + '_ {
        self.indptr.iter_outer_sz().map(move |range| {
            CsVecViewI::new_trusted(
                self.inner_dims(),
                &self.indices[range.clone()],
                &self.data[range],
            )
        })
    }

    /// Iterator over the outer slices in the order given by `perm`.
    ///
    /// For each `outer_ind` in `0..outer_dims()`, yields
    /// `(perm.at(outer_ind), slice at that permuted position)`.
    #[doc(hidden)]
    pub fn outer_iterator_papt<'a, 'perm: 'a>(
        &'a self,
        perm: PermViewI<'perm, I>,
    ) -> impl std::iter::DoubleEndedIterator<Item = (usize, CsVecViewI<'a, N, I>)>
           + std::iter::ExactSizeIterator<Item = (usize, CsVecViewI<'a, N, I>)>
           + 'a {
        (0..self.outer_dims()).map(move |outer_ind| {
            let outer_ind_perm = perm.at(outer_ind);
            let range = self.indptr.outer_inds_sz(outer_ind_perm);
            let indices = &self.indices[range.clone()];
            let data = &self.data[range];
            let vec = CsVecBase::new_trusted(self.inner_dims(), indices, data);
            (outer_ind_perm, vec)
        })
    }

    /// Largest number of stored elements in a single outer slice (0 when
    /// there is no outer slice).
    pub fn max_outer_nnz(&self) -> usize {
        self.outer_iterator()
            .map(|outer| outer.indices().len())
            .max()
            .unwrap_or(0)
    }

    /// For each outer slice, the number of stored elements off the diagonal
    /// (inner index different from the outer index).
    pub fn degrees(&self) -> Vec<usize> {
        self.outer_iterator()
            .enumerate()
            .map(|(outer_dim, outer)| {
                outer
                    .indices()
                    .iter()
                    .filter(|ind| ind.index() != outer_dim)
                    .count()
            })
            .collect()
    }

    /// Outer slice `i` as a sparse vector, or `None` if `i` is out of range.
    pub fn outer_view(&self, i: usize) -> Option<CsVecViewI<'_, N, I>> {
        if i >= self.outer_dims() {
            return None;
        }
        let range = self.indptr.outer_inds_sz(i);
        Some(CsVecViewI::new_trusted(
            self.inner_dims(),
            &self.indices[range.clone()],
            &self.data[range],
        ))
    }

    /// The diagonal as a sparse vector of dimension `min(rows, cols)`,
    /// containing only the stored diagonal elements.
    pub fn diag(&self) -> CsVecI<N, I>
    where
        N: Clone,
    {
        let shape = self.shape();
        let smallest_dim: usize = cmp::min(shape.0, shape.1);
        // Initial capacity guess; the vectors are shrunk to fit at the end.
        let heuristic = smallest_dim / 2;
        let mut index_vec = Vec::with_capacity(heuristic);
        let mut data_vec = Vec::with_capacity(heuristic);

        for i in 0..smallest_dim {
            let optional_index = self.nnz_index(i, i);
            if let Some(idx) = optional_index {
                data_vec.push(self[idx].clone());
                index_vec.push(I::from_usize(i));
            }
        }
        data_vec.shrink_to_fit();
        index_vec.shrink_to_fit();
        CsVecI::new_trusted(smallest_dim, index_vec, data_vec)
    }

    /// Iterator over the `min(rows, cols)` diagonal positions, yielding the
    /// stored value or `None` for each.
    pub fn diag_iter(
        &self,
    ) -> impl ExactSizeIterator<Item = Option<&N>>
           + DoubleEndedIterator<Item = Option<&N>> {
        let smallest_dim = cmp::min(self.ncols, self.nrows);
        // (i, i) is the same location in outer/inner and row/col terms.
        (0..smallest_dim).map(move |i| self.get_outer_inner(i, i))
    }

    /// Iterator over views of `block_size` consecutive outer slices; the
    /// last block may be shorter.
    ///
    /// # Panics
    ///
    /// If `block_size` is 0 (`Iterator::step_by` rejects a zero step).
    pub fn outer_block_iter(
        &self,
        block_size: usize,
    ) -> impl std::iter::DoubleEndedIterator<Item = CsMatViewI<'_, N, I, Iptr>>
           + std::iter::ExactSizeIterator<Item = CsMatViewI<'_, N, I, Iptr>>
           + '_ {
        (0..self.outer_dims()).step_by(block_size).map(move |i| {
            let count = if i + block_size > self.outer_dims() {
                self.outer_dims() - i
            } else {
                block_size
            };
            self.view().slice_outer_rbr(i..i + count)
        })
    }

    /// New matrix with the same structure and each stored value mapped by
    /// `f`.
    pub fn map<F, N2>(&self, f: F) -> CsMatI<N2, I, Iptr>
    where
        F: FnMut(&N) -> N2,
    {
        let data: Vec<N2> = self.data.iter().map(f).collect();

        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: self.indptr.to_owned(),
            indices: self.indices.to_vec(),
            data,
        }
    }

    /// Reference to the element at `(outer_ind, inner_ind)` in storage
    /// coordinates, or `None` if out of range or not stored.
    pub fn get_outer_inner(
        &self,
        outer_ind: usize,
        inner_ind: usize,
    ) -> Option<&N> {
        self.outer_view(outer_ind)
            .and_then(|vec| vec.get_rbr(inner_ind))
    }

    /// Position in the nnz arrays of the element at `(row, col)`, or `None`
    /// if nothing is stored there.
    pub fn nnz_index(&self, row: usize, col: usize) -> Option<NnzIndex> {
        match self.storage() {
            CSR => self.nnz_index_outer_inner(row, col),
            CSC => self.nnz_index_outer_inner(col, row),
        }
    }

    /// Position in the nnz arrays of the element at
    /// `(outer_ind, inner_ind)` in storage coordinates.
    ///
    /// The position found inside the outer slice is shifted by the slice's
    /// start offset, so it indexes the whole `data` array.
    pub fn nnz_index_outer_inner(
        &self,
        outer_ind: usize,
        inner_ind: usize,
    ) -> Option<NnzIndex> {
        if outer_ind >= self.outer_dims() {
            return None;
        }
        let offset = self.indptr.outer_inds_sz(outer_ind).start;
        self.outer_view(outer_ind)
            .and_then(|vec| vec.nnz_index(inner_ind))
            .map(|vec::NnzIndex(ind)| NnzIndex(ind + offset))
    }

    /// Checks that `indices` and `data` have the same length and that the
    /// arrays form a valid compressed structure for this shape.
    ///
    /// # Errors
    ///
    /// The `StructureError` describing the first inconsistency found.
    pub fn check_compressed_structure(&self) -> Result<(), StructureError> {
        let inner = self.inner_dims();
        let outer = self.outer_dims();

        if self.indices.len() != self.data.len() {
            return Err(StructureError::SizeMismatch(
                "Indices and data lengths do not match",
            ));
        }

        utils::check_compressed_structure(
            inner,
            outer,
            self.indptr.raw_storage(),
            &self.indices,
        )
    }

    /// Iterator over the stored elements in storage order, yielding
    /// `(&value, (row, col))`.
    pub fn iter(&self) -> CsIter<'_, N, I, Iptr> {
        CsIter {
            storage: self.storage,
            cur_outer: I::zero(),
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
        }
    }
}
1391
1392impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1394 CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1395where
1396 N: Default,
1397 I: SpIndex,
1398 Iptr: SpIndex,
1399 IptrStorage: Deref<Target = [Iptr]>,
1400 IndStorage: Deref<Target = [I]>,
1401 DataStorage: Deref<Target = [N]>,
1402{
1403 pub fn to_other_storage(&self) -> CsMatI<N, I, Iptr>
1406 where
1407 N: Clone,
1408 {
1409 let mut indptr = vec![Iptr::zero(); self.inner_dims() + 1];
1410 let mut indices = vec![I::zero(); self.nnz()];
1411 let mut data = vec![N::default(); self.nnz()];
1412 raw::convert_mat_storage(
1413 self.view(),
1414 &mut indptr,
1415 &mut indices,
1416 &mut data,
1417 );
1418 CsMatI {
1419 storage: self.storage().other_storage(),
1420 nrows: self.nrows,
1421 ncols: self.ncols,
1422 indptr: crate::IndPtr::new_trusted(indptr),
1423 indices,
1424 data,
1425 }
1426 }
1427
1428 pub fn to_csc(&self) -> CsMatI<N, I, Iptr>
1431 where
1432 N: Clone,
1433 {
1434 match self.storage {
1435 CSR => self.to_other_storage(),
1436 CSC => self.to_owned(),
1437 }
1438 }
1439
1440 pub fn to_csr(&self) -> CsMatI<N, I, Iptr>
1443 where
1444 N: Clone,
1445 {
1446 match self.storage {
1447 CSR => self.to_owned(),
1448 CSC => self.to_other_storage(),
1449 }
1450 }
1451}
1452
1453impl<N, I, Iptr> CsMatI<N, I, Iptr>
1454where
1455 N: Default,
1456
1457 I: SpIndex,
1458 Iptr: SpIndex,
1459{
1460 pub fn into_csc(self) -> Self
1464 where
1465 N: Clone,
1466 {
1467 match self.storage {
1468 CSR => self.to_other_storage(),
1469 CSC => self,
1470 }
1471 }
1472
1473 pub fn into_csr(self) -> Self
1477 where
1478 N: Clone,
1479 {
1480 match self.storage {
1481 CSR => self,
1482 CSC => self.to_other_storage(),
1483 }
1484 }
1485}
1486
1487impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1489 CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1490where
1491 I: SpIndex,
1492 Iptr: SpIndex,
1493 IptrStorage: Deref<Target = [Iptr]>,
1494 IndStorage: Deref<Target = [I]>,
1495 DataStorage: DerefMut<Target = [N]>,
1496{
1497 pub fn data_mut(&mut self) -> &mut [N] {
1503 &mut self.data[..]
1504 }
1505
1506 pub fn scale(&mut self, val: N)
1508 where
1509 for<'r> N: MulAssign<&'r N>,
1510 {
1511 for data in self.data_mut() {
1512 *data *= &val;
1513 }
1514 }
1515
1516 pub fn outer_view_mut(
1519 &mut self,
1520 i: usize,
1521 ) -> Option<CsVecViewMutI<'_, N, I>> {
1522 if i >= self.outer_dims() {
1523 return None;
1524 }
1525 let range = self.indptr.outer_inds_sz(i);
1526 Some(CsVecBase::new_trusted(
1528 self.inner_dims(),
1529 &self.indices[range.clone()],
1530 &mut self.data[range],
1531 ))
1532 }
1533
1534 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut N> {
1543 match self.storage {
1544 CSR => self.get_outer_inner_mut(i, j),
1545 CSC => self.get_outer_inner_mut(j, i),
1546 }
1547 }
1548
1549 pub fn get_outer_inner_mut(
1557 &mut self,
1558 outer_ind: usize,
1559 inner_ind: usize,
1560 ) -> Option<&mut N> {
1561 if let Some(NnzIndex(index)) =
1562 self.nnz_index_outer_inner(outer_ind, inner_ind)
1563 {
1564 Some(&mut self.data[index])
1565 } else {
1566 None
1567 }
1568 }
1569
1570 pub fn set(&mut self, row: usize, col: usize, val: N) {
1577 let outer = outer_dimension(self.storage(), row, col);
1578 let inner = inner_dimension(self.storage(), row, col);
1579 let vec::NnzIndex(index) = self
1580 .outer_view(outer)
1581 .and_then(|vec| vec.nnz_index(inner))
1582 .unwrap();
1583 self.data[index] = val;
1584 }
1585
1586 pub fn map_inplace<F>(&mut self, mut f: F)
1588 where
1589 F: FnMut(&N) -> N,
1590 {
1591 for val in &mut self.data[..] {
1592 *val = f(val);
1593 }
1594 }
1595
1596 pub fn outer_iterator_mut(
1602 &mut self,
1603 ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewMutI<'_, N, I>>
1604 + std::iter::ExactSizeIterator<Item = CsVecViewMutI<'_, N, I>>
1605 + '_ {
1606 let inner_dim = self.inner_dims();
1607 let indices = &self.indices[..];
1608 let data_ptr: *mut N = self.data.as_mut_ptr();
1609 self.indptr.iter_outer_sz().map(move |range| {
1610 let data: &mut [N] = unsafe {
1614 std::slice::from_raw_parts_mut(
1615 data_ptr.add(range.start),
1616 range.end - range.start,
1617 )
1618 };
1619
1620 CsVecViewMutI::new_trusted(inner_dim, &indices[range], data)
1621 })
1622 }
1623
1624 pub fn view_mut(&mut self) -> CsMatViewMutI<'_, N, I, Iptr> {
1626 CsMatViewMutI {
1627 storage: self.storage,
1628 nrows: self.nrows,
1629 ncols: self.ncols,
1630 indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1631 indices: &self.indices[..],
1632 data: &mut self.data[..],
1633 }
1634 }
1635
1636 pub fn diag_iter_mut(
1638 &mut self,
1639 ) -> impl ExactSizeIterator<Item = Option<&mut N>>
1640 + DoubleEndedIterator<Item = Option<&mut N>>
1641 + '_ {
1642 let data_ptr: *mut N = self.data[..].as_mut_ptr();
1643 let smallest_dim = cmp::min(self.ncols, self.nrows);
1644 (0..smallest_dim).map(move |i| {
1645 let idx = self.nnz_index_outer_inner(i, i);
1646 if let Some(NnzIndex(idx)) = idx {
1647 Some(unsafe { &mut *data_ptr.add(idx) })
1656 } else {
1657 None
1658 }
1659 })
1660 }
1661}
1662
1663impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1664 CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1665where
1666 I: SpIndex,
1667 Iptr: SpIndex,
1668 IptrStorage: DerefMut<Target = [Iptr]>,
1669 IndStorage: DerefMut<Target = [I]>,
1670 DataStorage: DerefMut<Target = [N]>,
1671{
1672 pub fn modify<F>(&mut self, mut f: F)
1706 where
1707 F: FnMut(&mut [Iptr], &mut [I], &mut [N]),
1708 {
1709 f(
1710 self.indptr.raw_storage_mut(),
1711 &mut self.indices[..],
1712 &mut self.data[..],
1713 );
1714 self.check_compressed_structure().unwrap();
1718 }
1719}
1720
1721pub mod raw {
1723 use crate::indexing::SpIndex;
1724 use crate::sparse::prelude::*;
1725 use std::mem::swap;
1726
1727 pub fn convert_mat_storage<N: Clone, I: SpIndex, Iptr: SpIndex>(
1783 mat: CsMatViewI<N, I, Iptr>,
1784 indptr: &mut [Iptr],
1785 indices: &mut [I],
1786 data: &mut [N],
1787 ) {
1788 assert_eq!(indptr.len(), mat.inner_dims() + 1);
1789 assert_eq!(indices.len(), mat.indices().len());
1790 assert_eq!(data.len(), mat.data().len());
1791
1792 assert!(indptr.iter().all(num_traits::Zero::is_zero));
1793
1794 assert!(
1795 I::try_from_usize(mat.rows()).is_some(),
1796 "Index type is not large enough to hold the number of rows requested (I::max_value={:?} vs. required {})", I::max_value(), mat.rows(),
1797 );
1798
1799 for vec in mat.outer_iterator() {
1800 for (inner_dim, _) in vec.iter() {
1801 indptr[inner_dim] += Iptr::one();
1802 }
1803 }
1804
1805 let mut cumsum = Iptr::zero();
1806 for iptr in indptr.iter_mut() {
1807 let tmp = *iptr;
1808 *iptr = cumsum;
1809 cumsum += tmp;
1810 }
1811 if let Some(last_iptr) = indptr.last() {
1812 assert_eq!(last_iptr.index(), mat.nnz());
1813 }
1814
1815 for (outer_dim, vec) in mat.outer_iterator().enumerate() {
1816 let outer_dim = I::from_usize_unchecked(outer_dim);
1817 for (inner_dim, val) in vec.iter() {
1818 let dest = indptr[inner_dim].index();
1819 data[dest] = val.clone();
1820 indices[dest] = outer_dim;
1821 indptr[inner_dim] += Iptr::one();
1822 }
1823 }
1824
1825 let mut last = Iptr::zero();
1826 for iptr in indptr.iter_mut() {
1827 swap(iptr, &mut last);
1828 }
1829 }
1830}
1831
1832impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::MulAssign<T>
1833 for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1834where
1835 I: SpIndex,
1836 Iptr: SpIndex,
1837 IpStorage: Deref<Target = [Iptr]>,
1838 IStorage: Deref<Target = [I]>,
1839 DStorage: DerefMut<Target = [T]>,
1840 T: std::ops::MulAssign<T> + Clone,
1841{
1842 fn mul_assign(&mut self, rhs: T) {
1843 self.data_mut()
1844 .iter_mut()
1845 .for_each(|v| v.mul_assign(rhs.clone()));
1846 }
1847}
1848
1849impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::DivAssign<T>
1850 for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1851where
1852 I: SpIndex,
1853 Iptr: SpIndex,
1854 IpStorage: Deref<Target = [Iptr]>,
1855 IStorage: Deref<Target = [I]>,
1856 DStorage: DerefMut<Target = [T]>,
1857 T: std::ops::DivAssign<T> + Clone,
1858{
1859 fn div_assign(&mut self, rhs: T) {
1860 self.data_mut()
1861 .iter_mut()
1862 .for_each(|v| v.div_assign(rhs.clone()));
1863 }
1864}
1865
1866impl<'a, 'b, N, I, Iptr, IpS1, IS1, DS1, IpS2, IS2, DS2>
1867 Mul<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
1868 for &'a CsMatBase<N, I, IpS1, IS1, DS1, Iptr>
1869where
1870 N: 'a + Clone + crate::MulAcc + num_traits::Zero + Default + Send + Sync,
1871 I: 'a + SpIndex,
1872 Iptr: 'a + SpIndex,
1873 IpS1: 'a + Deref<Target = [Iptr]>,
1874 IS1: 'a + Deref<Target = [I]>,
1875 DS1: 'a + Deref<Target = [N]>,
1876 IpS2: 'b + Deref<Target = [Iptr]>,
1877 IS2: 'b + Deref<Target = [I]>,
1878 DS2: 'b + Deref<Target = [N]>,
1879{
1880 type Output = CsMatI<N, I, Iptr>;
1881
1882 fn mul(
1883 self,
1884 rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
1885 ) -> CsMatI<N, I, Iptr> {
1886 csmat_mul_csmat(self, rhs)
1887 }
1888}
1889
1890pub fn csmat_mul_csmat<
1896 'a,
1897 'b,
1898 N,
1899 A,
1900 B,
1901 I,
1902 Iptr,
1903 IpS1,
1904 IS1,
1905 DS1,
1906 IpS2,
1907 IS2,
1908 DS2,
1909>(
1910 lhs: &'a CsMatBase<A, I, IpS1, IS1, DS1, Iptr>,
1911 rhs: &'b CsMatBase<B, I, IpS2, IS2, DS2, Iptr>,
1912) -> CsMatI<N, I, Iptr>
1913where
1914 N: 'a
1915 + Clone
1916 + crate::MulAcc<A, B>
1917 + crate::MulAcc<B, A>
1918 + num_traits::Zero
1919 + Default
1920 + Send
1921 + Sync,
1922 A: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1923 B: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1924 I: 'a + SpIndex,
1925 Iptr: 'a + SpIndex,
1926 IpS1: 'a + Deref<Target = [Iptr]>,
1927 IS1: 'a + Deref<Target = [I]>,
1928 DS1: 'a + Deref<Target = [A]>,
1929 IpS2: 'b + Deref<Target = [Iptr]>,
1930 IS2: 'b + Deref<Target = [I]>,
1931 DS2: 'b + Deref<Target = [B]>,
1932{
1933 match (lhs.storage(), rhs.storage()) {
1934 (CSR, CSR) => smmp::mul_csr_csr(lhs.view(), rhs.view()),
1935 (CSR, CSC) => {
1936 let rhs_csr = rhs.to_other_storage();
1937 smmp::mul_csr_csr(lhs.view(), rhs_csr.view())
1938 }
1939 (CSC, CSR) => {
1940 let rhs_csc = rhs.to_other_storage();
1941 smmp::mul_csr_csr(rhs_csc.transpose_view(), lhs.transpose_view())
1942 .transpose_into()
1943 }
1944 (CSC, CSC) => {
1945 smmp::mul_csr_csr(rhs.transpose_view(), lhs.transpose_view())
1946 .transpose_into()
1947 }
1948 }
1949}
1950
1951impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Add<&'b ArrayBase<DS2, Ix2>>
1952 for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1953where
1954 N: 'a + Copy + Num + Default,
1955 for<'r> &'r N: Mul<Output = N>,
1956 I: 'a + SpIndex,
1957 Iptr: 'a + SpIndex,
1958 IpS: 'a + Deref<Target = [Iptr]>,
1959 IS: 'a + Deref<Target = [I]>,
1960 DS: 'a + Deref<Target = [N]>,
1961 DS2: 'b + ndarray::Data<Elem = N>,
1962{
1963 type Output = Array<N, Ix2>;
1964
1965 fn add(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
1966 let is_standard_layout =
1967 utils::fastest_axis(rhs.view()) == ndarray::Axis(1);
1968 let neuter_element = N::one();
1969 match (self.storage(), is_standard_layout) {
1970 (CSR, true) | (CSC, false) => binop::add_dense_mat_same_ordering(
1971 self,
1972 rhs,
1973 neuter_element,
1974 neuter_element,
1975 ),
1976 (CSR, false) | (CSC, true) => {
1977 let lhs = self.to_other_storage();
1978 binop::add_dense_mat_same_ordering(
1979 &lhs,
1980 rhs,
1981 neuter_element,
1982 neuter_element,
1983 )
1984 }
1985 }
1986 }
1987}
1988
1989impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix2>>
1990 for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1991where
1992 N: 'a + crate::MulAcc + num_traits::Zero + Clone,
1993 I: 'a + SpIndex,
1994 Iptr: 'a + SpIndex,
1995 IpS: 'a + Deref<Target = [Iptr]>,
1996 IS: 'a + Deref<Target = [I]>,
1997 DS: 'a + Deref<Target = [N]>,
1998 DS2: 'b + ndarray::Data<Elem = N>,
1999{
2000 type Output = Array<N, Ix2>;
2001
2002 fn mul(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2003 let rows = self.rows();
2004 let cols = rhs.shape()[1];
2005 match (self.storage(), cols >= 8) {
2010 (CSR, true) => {
2011 let mut res = Array::zeros((rows, cols));
2012 prod::csr_mulacc_dense_rowmaj(
2013 self.view(),
2014 rhs.view(),
2015 res.view_mut(),
2016 );
2017 res
2018 }
2019 (CSR, false) => {
2020 let mut res = Array::zeros((rows, cols).f());
2021 prod::csr_mulacc_dense_colmaj(
2022 self.view(),
2023 rhs.view(),
2024 res.view_mut(),
2025 );
2026 res
2027 }
2028 (CSC, true) => {
2029 let mut res = Array::zeros((rows, cols));
2030 prod::csc_mulacc_dense_rowmaj(
2031 self.view(),
2032 rhs.view(),
2033 res.view_mut(),
2034 );
2035 res
2036 }
2037 (CSC, false) => {
2038 let mut res = Array::zeros((rows, cols).f());
2039 prod::csc_mulacc_dense_colmaj(
2040 self.view(),
2041 rhs.view(),
2042 res.view_mut(),
2043 );
2044 res
2045 }
2046 }
2047 }
2048}
2049
2050impl<N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
2051 for ArrayBase<DS2, Ix2>
2052where
2053 N: Clone + crate::MulAcc + num_traits::Zero + std::fmt::Debug,
2054 I: SpIndex,
2055 IpS: Deref<Target = [I]>,
2056 IS: Deref<Target = [I]>,
2057 DS: Deref<Target = [N]>,
2058 DS2: ndarray::Data<Elem = N>,
2059{
2060 type Output = Array<N, Ix2>;
2061
2062 fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
2063 let rhs_t = rhs.transpose_view();
2064 let lhs_t = self.t();
2065
2066 let rows = rhs_t.rows();
2067 let cols = lhs_t.ncols();
2068 let rres = match (rhs_t.storage(), cols >= 8) {
2073 (CSR, true) => {
2074 let mut res = Array::zeros((rows, cols));
2075 prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2076 res.reversed_axes()
2077 }
2078 (CSR, false) => {
2079 let mut res = Array::zeros((rows, cols).f());
2080 prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2081 res.reversed_axes()
2082 }
2083 (CSC, true) => {
2084 let mut res = Array::zeros((rows, cols));
2085 prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2086 res.reversed_axes()
2087 }
2088 (CSC, false) => {
2089 let mut res = Array::zeros((rows, cols).f());
2090 prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2091 res.reversed_axes()
2092 }
2093 };
2094
2095 assert_eq!(self.shape()[0], rres.shape()[0]);
2096 assert_eq!(rhs.cols(), rres.shape()[1]);
2097 rres
2098 }
2099}
2100
2101impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
2102 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2103where
2104 N: Clone + crate::MulAcc + num_traits::Zero,
2105 I: SpIndex,
2106 Iptr: SpIndex,
2107 IpS: Deref<Target = [Iptr]>,
2108 IS: Deref<Target = [I]>,
2109 DS: Deref<Target = [N]>,
2110 DS2: ndarray::Data<Elem = N>,
2111{
2112 type Output = Array<N, Ix2>;
2113
2114 fn dot(&self, rhs: &ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2115 Mul::mul(self, rhs)
2116 }
2117}
2118
2119impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix1>>
2120 for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2121where
2122 N: 'a + Clone + crate::MulAcc + num_traits::Zero,
2123 I: 'a + SpIndex,
2124 Iptr: 'a + SpIndex,
2125 IpS: 'a + Deref<Target = [Iptr]>,
2126 IS: 'a + Deref<Target = [I]>,
2127 DS: 'a + Deref<Target = [N]>,
2128 DS2: 'b + ndarray::Data<Elem = N>,
2129{
2130 type Output = Array<N, Ix1>;
2131
2132 fn mul(self, rhs: &'b ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2133 let rows = self.rows();
2134 let cols = rhs.shape()[0];
2135 #[allow(deprecated)]
2136 let rhs_reshape = rhs.view().into_shape((cols, 1)).unwrap();
2137 let mut res = Array::zeros(rows);
2138 {
2139 #[allow(deprecated)]
2140 let res_reshape = res.view_mut().into_shape((rows, 1)).unwrap();
2141 match self.storage() {
2142 CSR => {
2143 prod::csr_mulacc_dense_colmaj(
2144 self.view(),
2145 rhs_reshape,
2146 res_reshape,
2147 );
2148 }
2149 CSC => {
2150 prod::csc_mulacc_dense_colmaj(
2151 self.view(),
2152 rhs_reshape,
2153 res_reshape,
2154 );
2155 }
2156 }
2157 }
2158 res
2159 }
2160}
2161
2162impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix1>>
2163 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2164where
2165 N: Clone + crate::MulAcc + num_traits::Zero,
2166 I: SpIndex,
2167 Iptr: SpIndex,
2168 IpS: Deref<Target = [Iptr]>,
2169 IS: Deref<Target = [I]>,
2170 DS: Deref<Target = [N]>,
2171 DS2: ndarray::Data<Elem = N>,
2172{
2173 type Output = Array<N, Ix1>;
2174
2175 fn dot(&self, rhs: &ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2176 Mul::mul(self, rhs)
2177 }
2178}
2179
2180impl<N, I, Iptr, IpS, IS, DS> Index<[usize; 2]>
2181 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2182where
2183 I: SpIndex,
2184 Iptr: SpIndex,
2185 IpS: Deref<Target = [Iptr]>,
2186 IS: Deref<Target = [I]>,
2187 DS: Deref<Target = [N]>,
2188{
2189 type Output = N;
2190
2191 fn index(&self, index: [usize; 2]) -> &N {
2192 let i = index[0];
2193 let j = index[1];
2194 self.get(i, j).unwrap()
2195 }
2196}
2197
2198impl<N, I, Iptr, IpS, IS, DS> IndexMut<[usize; 2]>
2199 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2200where
2201 I: SpIndex,
2202 Iptr: SpIndex,
2203 IpS: Deref<Target = [Iptr]>,
2204 IS: Deref<Target = [I]>,
2205 DS: DerefMut<Target = [N]>,
2206{
2207 fn index_mut(&mut self, index: [usize; 2]) -> &mut N {
2208 let i = index[0];
2209 let j = index[1];
2210 self.get_mut(i, j).unwrap()
2211 }
2212}
2213
2214impl<N, I, Iptr, IpS, IS, DS> Index<NnzIndex>
2215 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2216where
2217 I: SpIndex,
2218 Iptr: SpIndex,
2219 IpS: Deref<Target = [Iptr]>,
2220 IS: Deref<Target = [I]>,
2221 DS: Deref<Target = [N]>,
2222{
2223 type Output = N;
2224
2225 fn index(&self, index: NnzIndex) -> &N {
2226 let NnzIndex(i) = index;
2227 self.data().get(i).unwrap()
2228 }
2229}
2230
2231impl<N, I, Iptr, IpS, IS, DS> IndexMut<NnzIndex>
2232 for CsMatBase<N, I, IpS, IS, DS, Iptr>
2233where
2234 I: SpIndex,
2235 Iptr: SpIndex,
2236 IpS: Deref<Target = [Iptr]>,
2237 IS: Deref<Target = [I]>,
2238 DS: DerefMut<Target = [N]>,
2239{
2240 fn index_mut(&mut self, index: NnzIndex) -> &mut N {
2241 let NnzIndex(i) = index;
2242 self.data_mut().get_mut(i).unwrap()
2243 }
2244}
2245
/// Exposes the shape and number of stored entries of a compressed matrix
/// through the `SparseMat` trait.
impl<N, I, Iptr, IpS, IS, DS> SparseMat for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    // Each method forwards to the same-named method called on `self`.
    // Method resolution prefers inherent methods over trait methods for the
    // same receiver type, so these are not recursive calls.
    // NOTE(review): this relies on `CsMatBase` having inherent `rows`, `cols`
    // and `nnz` under these bounds. If one were removed, the call would
    // resolve back to this trait method and recurse forever.
    fn rows(&self) -> usize {
        self.rows()
    }

    fn cols(&self) -> usize {
        self.cols()
    }

    fn nnz(&self) -> usize {
        self.nnz()
    }
}
2266
/// `SparseMat` for a borrowed compressed matrix, so that `&CsMatBase` can be
/// passed wherever a `SparseMat` is expected.
impl<'a, N, I, Iptr, IpS, IS, DS> SparseMat
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    // `*self` is the `&CsMatBase`, so each call resolves to the method on
    // `CsMatBase` (the inherent one if present) rather than back to this
    // impl, whose receiver is `&&CsMatBase`.
    fn rows(&self) -> usize {
        (*self).rows()
    }

    fn cols(&self) -> usize {
        (*self).cols()
    }

    fn nnz(&self) -> usize {
        (*self).nnz()
    }
}
2289
2290impl<'a, N, I, IpS, IS, DS, Iptr> IntoIterator
2291 for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2292where
2293 I: 'a + SpIndex,
2294 Iptr: 'a + SpIndex,
2295 N: 'a,
2296 IpS: Deref<Target = [Iptr]>,
2297 IS: Deref<Target = [I]>,
2298 DS: Deref<Target = [N]>,
2299{
2300 type Item = (&'a N, (I, I));
2301 type IntoIter = CsIter<'a, N, I, Iptr>;
2302 fn into_iter(self) -> Self::IntoIter {
2303 self.iter()
2304 }
2305}
2306
/// Consuming a view iterates over its stored entries as
/// `(&value, (row, col))`.
impl<'a, N, I, Iptr> IntoIterator for CsMatViewI<'a, N, I, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
{
    type Item = (&'a N, (I, I));
    type IntoIter = CsIter<'a, N, I, Iptr>;
    fn into_iter(self) -> Self::IntoIter {
        // `iter_rbr` is used instead of `iter`, presumably so that the
        // yielded references carry the view's lifetime `'a` rather than a
        // borrow of the local `self`, which is dropped when this returns.
        // NOTE(review): confirm against the definition of `iter_rbr`.
        self.iter_rbr()
    }
}
2319
#[cfg(test)]
mod test {
    // Unit tests for compressed sparse matrices: construction and structure
    // validation, conversions, indexing, iteration, diagonal access and the
    // arithmetic operators.
    use super::CompressedStorage::CSR;
    use crate::errors::StructureErrorKind;
    use crate::sparse::{CsMat, CsMatI, CsMatView, CsVec};
    use crate::test_data::{mat1, mat1_csc, mat1_times_2};
    use ndarray::{arr2, Array};

    #[test]
    fn test_copy() {
        let m = mat1();
        let view1 = m.view();
        // Views are `Copy`, so `view1` is still usable after this.
        let view2 = view1;
        assert_eq!(view1, view2);
    }

    #[test]
    fn test_new_csr_success() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let m = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_ok);
        assert!(m.is_ok());
    }

    #[test]
    #[should_panic]
    fn test_new_csr_bad_indptr_length() {
        let indptr_fail1: &[usize] = &[0, 1, 2];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_fail1, indices_ok, data_ok);
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_out_of_bounds_index() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indices_fail2: &[usize] = &[0, 1, 4];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail2, data_ok);
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_bad_nnz_count() {
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indptr_fail2: &[usize] = &[0, 1, 2, 4];
        let res = CsMatView::try_new((3, 3), indptr_fail2, indices_ok, data_ok);
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch1() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indices_fail1: &[usize] = &[0, 1];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail1, data_ok);
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch2() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_fail1: &[f64] = &[1., 1., 1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail1);
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch3() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_fail2: &[f64] = &[1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail2);
        res.unwrap();
    }

    #[test]
    fn test_new_csr_fails() {
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indptr_fail3: &[usize] = &[0, 2, 1, 3];
        assert_eq!(
            CsMatView::try_new((3, 3), indptr_fail3, indices_ok, data_ok)
                .unwrap_err()
                .3
                .kind(),
            StructureErrorKind::Unsorted
        );
    }

    #[test]
    fn test_new_csr_fail_indices_ordering() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        // Row 0 lists column 3 before column 2.
        let indices: &[usize] = &[3, 2, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];
        assert_eq!(
            CsMatView::try_new((5, 5), indptr, indices, data)
                .unwrap_err()
                .3
                .kind(),
            StructureErrorKind::Unsorted
        );
    }

    #[test]
    fn test_new_csr_csc_success() {
        let indptr_ok: &[usize] = &[0, 2, 5, 6];
        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
        let data_ok: &[f64] = &[
            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
            0.46202352,
        ];
        assert!(
            CsMatView::try_new((3, 4), indptr_ok, indices_ok, data_ok).is_ok()
        );
        assert!(
            CsMatView::try_new_csc((4, 3), indptr_ok, indices_ok, data_ok)
                .is_ok()
        );
    }

    #[test]
    #[should_panic]
    fn test_new_csc_bad_indptr_length() {
        let indptr_ok: &[usize] = &[0, 2, 5, 6];
        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
        let data_ok: &[f64] = &[
            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
            0.46202352,
        ];
        let res =
            CsMatView::try_new_csc((3, 4), indptr_ok, indices_ok, data_ok);
        res.unwrap();
    }

    #[test]
    fn test_new_csr_vec_borrowed() {
        let indptr_ok = vec![0, 1, 2, 3];
        let indices_ok = vec![0, 1, 2];
        let data_ok: Vec<f64> = vec![1., 1., 1.];
        assert!(
            CsMatView::try_new((3, 3), &indptr_ok, &indices_ok, &data_ok)
                .is_ok()
        );
    }

    #[test]
    fn test_new_csr_vec_owned() {
        let indptr_ok = vec![0, 1, 2, 3];
        let indices_ok = vec![0, 1, 2];
        let data_ok: Vec<f64> = vec![1., 1., 1.];
        assert!(CsMat::new_from_unsorted(
            (3, 3),
            indptr_ok,
            indices_ok,
            data_ok
        )
        .is_ok());
    }

    #[test]
    fn test_csr_from_dense() {
        let m = Array::eye(3);
        let m_sparse = CsMat::csr_from_dense(m.view(), 0.);

        assert_eq!(m_sparse, CsMat::eye(3));

        let m = arr2(&[
            [1., 0., 2., 1e-7, 1.],
            [0., 0., 0., 1., 0.],
            [3., 0., 1., 0., 0.],
        ]);
        let m_sparse = CsMat::csr_from_dense(m.view(), 1e-5);

        let expected_output = CsMat::new(
            (3, 5),
            vec![0, 3, 4, 6],
            vec![0, 2, 4, 3, 0, 2],
            vec![1., 2., 1., 1., 3., 1.],
        );

        assert_eq!(m_sparse, expected_output);
    }

    #[test]
    fn test_csc_from_dense() {
        let m = Array::eye(3);
        let m_sparse = CsMat::csc_from_dense(m.view(), 0.);

        assert_eq!(m_sparse, CsMat::eye_csc(3));

        let m = arr2(&[
            [1., 0., 2., 1e-7, 1.],
            [0., 0., 0., 1., 0.],
            [3., 0., 1., 0., 0.],
        ]);
        let m_sparse = CsMat::csc_from_dense(m.view(), 1e-5);

        let expected_output = CsMat::new_csc(
            (3, 5),
            vec![0, 2, 2, 4, 5, 6],
            vec![0, 2, 0, 2, 1, 0],
            vec![1., 3., 2., 1., 1., 1.],
        );

        assert_eq!(m_sparse, expected_output);
    }

    #[test]
    fn owned_csr_unsorted_indices() {
        let indptr = vec![0, 3, 3, 5, 6, 7];
        let indices_sorted = &[1, 2, 3, 2, 3, 4, 4];
        let indices_shuffled = vec![1, 3, 2, 2, 3, 4, 4];
        let mut data: Vec<i32> = (0..7).collect();
        let m = CsMat::new_from_unsorted(
            (5, 5),
            indptr,
            indices_shuffled,
            data.clone(),
        )
        .unwrap();
        assert_eq!(m.indices(), indices_sorted);
        data.swap(1, 2);
        assert_eq!(m.data(), &data[..]);
    }

    #[test]
    fn new_csr_with_empty_row() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];
        assert!(CsMatView::try_new((5, 5), indptr, indices, data).is_ok());
    }

    #[test]
    fn csr_to_csc() {
        let a = mat1();
        let a_csc_ground_truth = mat1_csc();
        let a_csc = a.to_other_storage();
        assert_eq!(a_csc, a_csc_ground_truth);
    }

    #[test]
    fn test_self_smul() {
        let mut a = mat1();
        a.scale(2.);
        let c_true = mat1_times_2();
        assert_eq!(a.indptr(), c_true.indptr());
        assert_eq!(a.indices(), c_true.indices());
        assert_eq!(a.data(), c_true.data());
    }

    #[test]
    fn outer_block_iter() {
        let mat: CsMat<f64> = CsMat::eye(11);
        let mut block_iter = mat.outer_block_iter(3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 2);
        assert_eq!(block_iter.next(), None);

        let mut block_iter = mat.outer_block_iter(4);
        assert_eq!(block_iter.next().unwrap().cols(), 11);
        block_iter.next().unwrap();
        block_iter.next().unwrap();
        assert_eq!(block_iter.next(), None);
    }

    #[test]
    fn middle_outer_views() {
        let size = 11;
        let csr: CsMat<f64> = CsMat::eye(size);
        #[allow(deprecated)]
        let v = csr.view().middle_outer_views(1, 3);
        assert_eq!(v.shape(), (3, size));
        assert_eq!(v.nnz(), 3);

        let csc = csr.to_other_storage();
        #[allow(deprecated)]
        let v = csc.view().middle_outer_views(1, 3);
        assert_eq!(v.shape(), (size, 3));
        assert_eq!(v.nnz(), 3);
    }

    #[test]
    fn nnz_index() {
        let mat: CsMat<f64> = CsMat::eye(11);

        assert_eq!(mat.nnz_index(2, 3), None);
        assert_eq!(mat.nnz_index(5, 7), None);
        assert_eq!(mat.nnz_index(0, 11), None);
        assert_eq!(mat.nnz_index(0, 0), Some(super::NnzIndex(0)));
        assert_eq!(mat.nnz_index(7, 7), Some(super::NnzIndex(7)));
        assert_eq!(mat.nnz_index(10, 10), Some(super::NnzIndex(10)));

        let index = mat.nnz_index(8, 8).unwrap();
        assert_eq!(mat[index], 1.);
        let mut mat = mat;
        mat[index] = 2.;
        assert_eq!(mat[index], 2.);
    }

    #[test]
    fn index() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 2., 3., 4.],
        );
        assert_eq!(mat[[1, 0]], 1.);
        assert_eq!(mat[[0, 1]], 2.);
        assert_eq!(mat[[2, 1]], 3.);
        assert_eq!(mat[[2, 2]], 4.);
        assert_eq!(mat.get(0, 0), None);
        assert_eq!(mat.get(4, 4), None);
    }

    #[test]
    fn get_mut() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        *mat.get_mut(2, 1).unwrap() = 3.;

        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 1.],
        );

        assert_eq!(mat, exp);

        mat[[2, 2]] = 5.;
        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 5.],
        );

        assert_eq!(mat, exp);
    }

    #[test]
    fn map() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        let mut res = mat.map(|&x| x + 2.);
        let expected = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![3.; 4],
        );
        assert_eq!(res, expected);

        res.map_inplace(|&x| x / 3.);
        assert_eq!(res, mat);
    }

    #[test]
    fn insert() {
        let mut mat = CsMat::empty(CSR, 0);
        mat.reserve_outer_dim(3);
        mat.reserve_nnz(4);
        mat.insert(0, 1, 1.);
        mat.insert(1, 0, 1.);
        mat.insert(2, 1, 1.);
        mat.insert(2, 2, 1.);

        let expected =
            CsMat::new((3, 3), vec![0, 1, 2, 4], vec![1, 0, 1, 2], vec![1.; 4]);
        assert_eq!(mat, expected);

        mat.insert(0, 0, 2.);
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 1., 1., 1.],
        );
        assert_eq!(mat, expected);

        mat.insert(1, 0, 3.);
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 3., 1., 1.],
        );
        assert_eq!(mat, expected);
    }

    #[test]
    fn bug_129() {
        let mut mat = CsMat::zero((3, 100));
        mat.insert(2, 3, 42);
        let mut iter = mat.iter();
        assert_eq!(iter.next(), Some((&42, (2, 3))));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn iter_mut() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        for mut col_vec in mat.outer_iterator_mut() {
            for (row_ind, val) in col_vec.iter_mut() {
                *val = row_ind as f64 + 1.;
            }
        }

        let expected = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![2., 1., 3., 3.],
        );
        assert_eq!(mat, expected);
    }

    #[test]
    #[should_panic]
    fn modify_fail() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        mat.modify(|indptr, indices, data| {
            indptr[1] = 2;
            indptr[2] = 4;
            indices[0] = 0;
            indices[1] = 1;
            data[2] = 2.;
        });
    }

    #[test]
    fn convert_types() {
        let mat: CsMat<f32> = CsMat::eye(3);
        let mat_: CsMatI<f64, u32> = mat.to_other_types();
        assert_eq!(mat_.indptr(), &[0, 1, 2, 3][..]);

        let mat = CsMatI::new_csc(
            (3, 3),
            vec![0u32, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        let mat_: CsMatI<f32, usize, u32> = mat.to_other_types();
        assert_eq!(mat_.indptr(), &[0, 1, 3, 4][..]);
        assert_eq!(mat_.data(), &[1.0f32, 1., 1., 1.]);
    }

    #[test]
    fn iter() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        let mut iter = mat.iter();
        assert_eq!(iter.next(), Some((&1., (1, 0))));
        assert_eq!(iter.next(), Some((&1., (0, 1))));
        assert_eq!(iter.next(), Some((&1., (2, 1))));
        assert_eq!(iter.next(), Some((&1., (2, 2))));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn degrees() {
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );

        let degrees = mat.degrees();
        assert_eq!(&degrees, &[2, 0, 1, 2, 1],);
    }

    #[test]
    fn diag() {
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );

        let diag = mat.diag();
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
        assert_eq!(diag, expected);

        let mut iter = mat.diag_iter();
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&2));
        assert_eq!(iter.next().unwrap(), None);
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next(), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn diag_mut() {
        let mut mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );

        let mut diags = mat.diag_iter_mut().collect::<Vec<_>>();
        if let Some(x) = diags[4].as_mut() {
            **x *= 3;
        }
        if let Some(x) = diags[3].as_mut() {
            **x -= 4;
        }
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, -3, 3]);
        assert_eq!(mat.diag(), expected);
    }

    #[test]
    fn diag_rectangular() {
        let mat = CsMat::new_csc(
            (5, 6),
            vec![0, 3, 4, 5, 8, 10, 12],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4, 0, 2],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1, 3, 1],
        );

        let diag = mat.diag();
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
        assert_eq!(diag, expected);

        let mut iter = mat.diag_iter();
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&2));
        assert_eq!(iter.next().unwrap(), None);
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn onehot_zero() {
        let onehot: CsMat<f32> = CsMat::zero((3, 3)).to_inner_onehot();

        assert!(onehot.is_csr());
        assert_eq!(CsMat::zero((3, 3)), onehot);
    }

    #[test]
    fn onehot_eye() {
        let mat = CsMat::new(
            (2, 2),
            vec![0, 2, 4],
            vec![0, 1, 0, 1],
            vec![2.0, 0.0, 0.0, 2.0],
        );

        let onehot = mat.to_inner_onehot();

        assert!(onehot.is_csr());
        assert_eq!(CsMat::eye(2), onehot);
    }

    #[test]
    fn onehot_sparse_csc() {
        let mat = CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![2.0]);

        let onehot = mat.to_inner_onehot();

        let expected =
            CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![1.0]);

        assert!(onehot.is_csc());
        assert_eq!(expected, onehot);
    }

    #[test]
    fn onehot_ignores_nan() {
        let mat = CsMat::new(
            (2, 2),
            vec![0, 2, 3],
            vec![0, 1, 1],
            vec![2.0, f64::NAN, 2.0],
        );

        let onehot = mat.to_inner_onehot();

        assert!(onehot.is_csr());
        assert_eq!(CsMat::eye(2), onehot);
    }

    #[test]
    fn mul_assign() {
        let mut m1 = crate::TriMat::new((6, 9));
        m1.add_triplet(1, 1, 8_i32);
        m1.add_triplet(1, 2, 7);
        m1.add_triplet(0, 1, 6);
        m1.add_triplet(0, 8, 5);
        m1.add_triplet(4, 2, 4);
        let mut m1: CsMat<_> = m1.to_csr();

        m1 *= 2;
        for (&v, (j, i)) in m1.iter() {
            match (j, i) {
                (1, 1) => assert_eq!(v, 16),
                (1, 2) => assert_eq!(v, 14),
                (0, 1) => assert_eq!(v, 12),
                (0, 8) => assert_eq!(v, 10),
                (4, 2) => assert_eq!(v, 8),
                _ => panic!(),
            }
        }
    }

    #[test]
    fn div_assign() {
        let mut m1 = crate::TriMat::new((6, 9));
        m1.add_triplet(1, 1, 8_i32);
        m1.add_triplet(1, 2, 7);
        m1.add_triplet(0, 1, 6);
        m1.add_triplet(0, 8, 5);
        m1.add_triplet(4, 2, 4);
        let mut m1: CsMat<_> = m1.to_csr();

        m1 /= 2;
        for (&v, (j, i)) in m1.iter() {
            match (j, i) {
                (1, 1) => assert_eq!(v, 4),
                (1, 2) => assert_eq!(v, 3),
                (0, 1) => assert_eq!(v, 3),
                (0, 8) => assert_eq!(v, 2),
                (4, 2) => assert_eq!(v, 2),
                _ => panic!(),
            }
        }
    }

    #[test]
    fn issue_99() {
        let a = crate::TriMat::<i32>::new((10, 1)).to_csc::<usize>();
        let b = crate::TriMat::<i32>::new((1, 9)).to_csr();
        let _c = &a * &b;
    }
}
3054
#[cfg(feature = "approx")]
mod approx_impls {
    //! Implementations of the `approx` crate's `AbsDiffEq`, `UlpsEq` and
    //! `RelativeEq` traits for compressed sparse matrices.
    //!
    //! Matrices of different shapes never compare equal. Matrices sharing a
    //! storage order are compared outer vector by outer vector, delegating to
    //! the sparse-vector implementations of the same traits. Matrices with
    //! different storage orders are compared entry by entry through
    //! `CsMatBase::get`, with a position absent from the looked-up matrix
    //! compared against `N::zero()`.
    use super::*;
    use approx::*;

    /// Entry-by-entry comparison used when `lhs` and `rhs` have different
    /// storage orders, so their outer iterators cannot be zipped together.
    ///
    /// Every stored entry of `lhs` is checked with `eq(lhs_value, rhs_value)`,
    /// then every stored entry of `rhs` with `eq(rhs_value, lhs_value)`, where
    /// a value missing from the looked-up matrix is replaced by `N::zero()`.
    /// Each pass short-circuits on the first mismatch. Checking both sides is
    /// what catches entries stored in only one of the two matrices.
    ///
    /// # Panics
    ///
    /// Panics if a stored index cannot be represented as `usize`.
    fn entries_match<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2, F>(
        lhs: &CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>,
        rhs: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
        eq: F,
    ) -> bool
    where
        I: SpIndex,
        Iptr: SpIndex,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: num_traits::Zero,
        F: Fn(&N, &N) -> bool,
    {
        // Built once instead of once per compared element.
        let zero = N::zero();
        let index = |ind: I| {
            ind.to_usize()
                .expect("sparse index does not fit in usize")
        };
        lhs.iter().all(|(a, (row, col))| {
            eq(a, rhs.get(index(row), index(col)).unwrap_or(&zero))
        }) && rhs.iter().all(|(b, (row, col))| {
            eq(b, lhs.get(index(row), index(col)).unwrap_or(&zero))
        })
    }

    // The `PartialEq` bound on the matrix type is required because the
    // `approx` traits have `PartialEq<Rhs>` as a supertrait.
    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        AbsDiffEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: AbsDiffEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        type Epsilon = N::Epsilon;

        fn default_epsilon() -> N::Epsilon {
            N::default_epsilon()
        }

        fn abs_diff_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.abs_diff_eq(&r2, epsilon.clone()))
            } else {
                entries_match(self, other, |a, b| {
                    a.abs_diff_eq(b, epsilon.clone())
                })
            }
        }
    }

    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        UlpsEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: UlpsEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_ulps() -> u32 {
            N::default_max_ulps()
        }

        fn ulps_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_ulps: u32,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.ulps_eq(&r2, epsilon.clone(), max_ulps))
            } else {
                entries_match(self, other, |a, b| {
                    a.ulps_eq(b, epsilon.clone(), max_ulps)
                })
            }
        }
    }

    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        RelativeEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: RelativeEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_relative() -> N::Epsilon {
            N::default_max_relative()
        }

        fn relative_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_relative: N::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator().zip(other.outer_iterator()).all(
                    |(r1, r2)| {
                        r1.relative_eq(
                            &r2,
                            epsilon.clone(),
                            max_relative.clone(),
                        )
                    },
                )
            } else {
                entries_match(self, other, |a, b| {
                    a.relative_eq(b, epsilon.clone(), max_relative.clone())
                })
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use crate::*;

        #[test]
        fn different_shapes() {
            let mut m1 = TriMat::new((3, 2));
            m1.add_triplet(1, 1, 8_u8);
            let m1: CsMat<_> = m1.to_csr();
            let mut m2 = TriMat::new((2, 3));
            m2.add_triplet(1, 1, 8_u8);
            let m2 = m2.to_csr();

            ::approx::assert_abs_diff_ne!(m1, m2);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc());
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8_u8);
            m1.add_triplet(1, 2, 7_u8);
            m1.add_triplet(0, 1, 6_u8);
            m1.add_triplet(0, 8, 5_u8);
            m1.add_triplet(4, 2, 4_u8);

            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();

            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0
            );

            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);

            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();

            ::approx::assert_abs_diff_eq!(m1, m2);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc());
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2.to_csc());

            ::approx::assert_relative_eq!(m1, m2);
            ::approx::assert_relative_eq!(m1.to_csc(), m2);
            ::approx::assert_relative_eq!(m1, m2.to_csc());
            ::approx::assert_relative_eq!(m1.to_csc(), m2.to_csc());

            ::approx::assert_ulps_eq!(m1, m2);
            ::approx::assert_ulps_eq!(m1.to_csc(), m2);
            ::approx::assert_ulps_eq!(m1, m2.to_csc());
            ::approx::assert_ulps_eq!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn almost_equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);
            let m1: CsMat<_> = m1.to_csr();

            let mut m2 = TriMat::new((6, 9));
            m2.add_triplet(1, 1, 8.0_f32);
            m2.add_triplet(1, 2, 7.0 - 0.5);
            m2.add_triplet(0, 1, 6.0);
            m2.add_triplet(0, 8, 5.0);
            m2.add_triplet(4, 2, 4.0);
            m2.add_triplet(4, 3, 0.2);
            let m2 = m2.to_csr();

            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.6
            );

            ::approx::assert_abs_diff_ne!(m1, m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc(), epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.4
            );
        }

        #[test]
        fn relative_and_ulps_detect_differences() {
            let mut m1 = TriMat::new((3, 4));
            m1.add_triplet(0, 1, 1.0_f64);
            m1.add_triplet(2, 3, 2.0);
            let m1: CsMat<_> = m1.to_csr();

            // Same sparsity pattern, one value changed.
            let mut m2 = TriMat::new((3, 4));
            m2.add_triplet(0, 1, 1.0_f64);
            m2.add_triplet(2, 3, 2.5);
            let m2: CsMat<_> = m2.to_csr();

            // Same values plus an entry that m1 does not store.
            let mut m3 = TriMat::new((3, 4));
            m3.add_triplet(0, 1, 1.0_f64);
            m3.add_triplet(2, 3, 2.0);
            m3.add_triplet(1, 0, 3.0);
            let m3: CsMat<_> = m3.to_csr();

            ::approx::assert_relative_ne!(m1, m2);
            ::approx::assert_relative_ne!(m1.to_csc(), m2);
            ::approx::assert_relative_ne!(m1, m3.to_csc());
            ::approx::assert_relative_ne!(m3.to_csc(), m1);

            ::approx::assert_ulps_ne!(m1, m2.to_csc());
            ::approx::assert_ulps_ne!(m3, m1.to_csc());
        }
    }
}