sprs_rssn/sparse/
csmat.rs

1//! A sparse matrix in the Compressed Sparse Row/Column format
2//!
//! In the CSR format, a matrix is a structure containing three vectors:
//! `indptr`, `indices`, and `data`.
//! These vectors satisfy the relation
//! for i in [0, nrows),
//! A(i, indices[indptr[i]..indptr[i+1]]) = data[indptr[i]..indptr[i+1]]
//! In the CSC format, the relation is, for i in [0, ncols),
//! A(indices[indptr[i]..indptr[i+1]], i) = data[indptr[i]..indptr[i+1]]
10use ndarray::ArrayView;
11use num_traits::{Float, Num, Signed, Zero};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use std::cmp;
15use std::default::Default;
16use std::iter::{Enumerate, Zip};
17use std::mem;
18use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, MulAssign};
19use std::slice::Iter;
20
21use crate::{Ix1, Ix2, Shape};
22use ndarray::linalg::Dot;
23use ndarray::{self, Array, ArrayBase, ShapeBuilder};
24
25use crate::indexing::SpIndex;
26
27use crate::errors::StructureError;
28use crate::sparse::binop;
29use crate::sparse::permutation::PermViewI;
30use crate::sparse::prelude::*;
31use crate::sparse::prod;
32use crate::sparse::smmp;
33use crate::sparse::to_dense::assign_to_dense;
34use crate::sparse::utils;
35use crate::sparse::vec;
36
/// Storage order of a compressed sparse matrix (`CsMat`).
///
/// The outer dimension (the one delimited by `indptr`) is the row
/// dimension for `CSR` and the column dimension for `CSC`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[allow(clippy::upper_case_acronyms)]
pub enum CompressedStorage {
    /// Compressed Sparse Row: each outer slice is a row
    CSR,
    /// Compressed Sparse Column: each outer slice is a column
    CSC,
}
47
48impl CompressedStorage {
49    /// Get the other storage, ie return CSC if we were CSR, and vice versa
50    pub fn other_storage(self) -> Self {
51        match self {
52            CSR => CSC,
53            CSC => CSR,
54        }
55    }
56}
57
58pub fn outer_dimension(
59    storage: CompressedStorage,
60    rows: usize,
61    cols: usize,
62) -> usize {
63    match storage {
64        CSR => rows,
65        CSC => cols,
66    }
67}
68
69pub fn inner_dimension(
70    storage: CompressedStorage,
71    rows: usize,
72    cols: usize,
73) -> usize {
74    match storage {
75        CSR => cols,
76        CSC => rows,
77    }
78}
79
80pub use self::CompressedStorage::{CSC, CSR};
81
/// Index of a non-zero element in the compressed storage.
///
/// Holding on to an `NnzIndex` makes it possible to access the same
/// non-zero element again later in constant time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NnzIndex(pub usize);
88
89pub struct CsIter<'a, N: 'a, I: 'a, Iptr: 'a = I>
90where
91    I: SpIndex,
92    Iptr: SpIndex,
93{
94    storage: CompressedStorage,
95    cur_outer: I,
96    indptr: crate::IndPtrView<'a, Iptr>,
97    inner_iter: Enumerate<Zip<Iter<'a, I>, Iter<'a, N>>>,
98}
99
100impl<'a, N, I, Iptr> Iterator for CsIter<'a, N, I, Iptr>
101where
102    I: SpIndex,
103    Iptr: SpIndex,
104    N: 'a,
105{
106    type Item = (&'a N, (I, I));
107    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
108        match self.inner_iter.next() {
109            None => None,
110            Some((nnz_index, (&inner_ind, val))) => {
111                // loop to find the correct outer dimension. Looping
112                // is necessary because there can be several adjacent
113                // empty outer dimensions.
114                loop {
115                    let nnz_end = self
116                        .indptr
117                        .outer_inds_sz(self.cur_outer.index_unchecked())
118                        .end;
119                    if nnz_index == nnz_end.index_unchecked() {
120                        self.cur_outer += I::one();
121                    } else {
122                        break;
123                    }
124                }
125                let (row, col) = match self.storage {
126                    CSR => (self.cur_outer, inner_ind),
127                    CSC => (inner_ind, self.cur_outer),
128                };
129                Some((val, (row, col)))
130            }
131        }
132    }
133
134    fn size_hint(&self) -> (usize, Option<usize>) {
135        self.inner_iter.size_hint()
136    }
137}
138
139impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
140    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
141where
142    IptrStorage: Deref<Target = [Iptr]>,
143    IStorage: Deref<Target = [I]>,
144    DStorage: Deref<Target = [N]>,
145{
146    pub(crate) fn new_checked(
147        storage: CompressedStorage,
148        shape: (usize, usize),
149        indptr: IptrStorage,
150        indices: IStorage,
151        data: DStorage,
152    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
153        let (nrows, ncols) = shape;
154        let (inner, outer) = match storage {
155            CSR => (ncols, nrows),
156            CSC => (nrows, ncols),
157        };
158        if data.len() != indices.len() {
159            return Err((
160                indptr,
161                indices,
162                data,
163                StructureError::SizeMismatch(
164                    "data and indices have different sizes",
165                ),
166            ));
167        }
168        match crate::sparse::utils::check_compressed_structure(
169            inner,
170            outer,
171            indptr.as_ref(),
172            indices.as_ref(),
173        ) {
174            Err(e) => Err((indptr, indices, data, e)),
175            Ok(_) => Ok(Self {
176                storage,
177                nrows,
178                ncols,
179                indptr: crate::IndPtrBase::new_trusted(indptr),
180                indices,
181                data,
182            }),
183        }
184    }
185
186    /// Create a new `CSR` sparse matrix
187    ///
188    /// See `new_csc` for the `CSC` equivalent
189    ///
190    /// This constructor can be used to construct all
191    /// sparse matrix types.
192    /// By using the type aliases one helps constrain the resulting type,
193    /// as shown below
194    ///
195    /// # Example
196    ///
197    /// ```rust
198    /// # use sprs::*;
199    /// // This creates an owned matrix
200    /// let owned_matrix = CsMat::new((2, 2), vec![0, 1, 1], vec![1], vec![4_u8]);
201    /// // This creates a matrix which only borrows the elements
202    /// let borrow_matrix = CsMatView::new((2, 2), &[0, 1, 1], &[1], &[4_u8]);
203    /// // A combination of storage types may also be used for a
204    /// // general sparse matrix
205    /// let mixed_matrix = CsMatBase::new((2, 2), &[0, 1, 1] as &[_], vec![1_i64].into_boxed_slice(), vec![4_u8]);
206    /// ```
207    pub fn new(
208        shape: (usize, usize),
209        indptr: IptrStorage,
210        indices: IStorage,
211        data: DStorage,
212    ) -> Self {
213        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
214            .map_err(|(_, _, _, e)| e)
215            .unwrap()
216    }
217
218    /// Create a new `CSC` sparse matrix
219    ///
220    /// See `new` for the `CSR` equivalent
221    pub fn new_csc(
222        shape: (usize, usize),
223        indptr: IptrStorage,
224        indices: IStorage,
225        data: DStorage,
226    ) -> Self {
227        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
228            .map_err(|(_, _, _, e)| e)
229            .unwrap()
230    }
231
232    /// Try to create a new `CSR` sparse matrix
233    ///
234    /// See `try_new_csc` for the `CSC` equivalent
235    pub fn try_new(
236        shape: (usize, usize),
237        indptr: IptrStorage,
238        indices: IStorage,
239        data: DStorage,
240    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
241        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
242    }
243
244    /// Try to create a new `CSC` sparse matrix
245    ///
246    /// See `new` for the `CSR` equivalent
247    pub fn try_new_csc(
248        shape: (usize, usize),
249        indptr: IptrStorage,
250        indices: IStorage,
251        data: DStorage,
252    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
253        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
254    }
255
256    /// Create a `CsMat` matrix from raw data,
257    /// without checking their validity
258    ///
259    /// # Safety
260    /// This is unsafe because algorithms are free to assume
261    /// that properties guaranteed by
262    /// [`check_compressed_structure`](Self::check_compressed_structure) are enforced.
263    /// For instance, non out-of-bounds indices can be relied upon to
264    /// perform unchecked slice access.
265    pub unsafe fn new_unchecked(
266        storage: CompressedStorage,
267        shape: Shape,
268        indptr: IptrStorage,
269        indices: IStorage,
270        data: DStorage,
271    ) -> Self {
272        let (nrows, ncols) = shape;
273        Self {
274            storage,
275            nrows,
276            ncols,
277            indptr: crate::IndPtrBase::new_trusted(indptr),
278            indices,
279            data,
280        }
281    }
282
283    /// Internal analog to `new_unchecked` which is not marked as `unsafe` as
284    /// we should always construct valid matrices internally
285    pub(crate) fn new_trusted(
286        storage: CompressedStorage,
287        shape: Shape,
288        indptr: IptrStorage,
289        indices: IStorage,
290        data: DStorage,
291    ) -> Self {
292        let (nrows, ncols) = shape;
293        Self {
294            storage,
295            nrows,
296            ncols,
297            indptr: crate::IndPtrBase::new_trusted(indptr),
298            indices,
299            data,
300        }
301    }
302}
303
304impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
305    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
306where
307    IptrStorage: Deref<Target = [Iptr]>,
308    IStorage: DerefMut<Target = [I]>,
309    DStorage: DerefMut<Target = [N]>,
310{
311    fn new_from_unsorted_checked(
312        storage: CompressedStorage,
313        shape: (usize, usize),
314        indptr: IptrStorage,
315        mut indices: IStorage,
316        mut data: DStorage,
317    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
318    where
319        N: Clone,
320    {
321        let (nrows, ncols) = shape;
322        let (inner, outer) = match storage {
323            CSR => (ncols, nrows),
324            CSC => (nrows, ncols),
325        };
326        if data.len() != indices.len() {
327            return Err((
328                indptr,
329                indices,
330                data,
331                StructureError::SizeMismatch(
332                    "data and indices have different sizes",
333                ),
334            ));
335        }
336        let mut buf = Vec::new();
337        for start_stop in indptr.windows(2) {
338            let start = start_stop[0].to_usize().unwrap();
339            let stop = start_stop[1].to_usize().unwrap();
340            let indices = &mut indices[start..stop];
341            if utils::sorted_indices(indices) {
342                continue;
343            }
344            let data = &mut data[start..stop];
345            let len = stop - start;
346            let indices = &mut indices[..len];
347            let data = &mut data[..len];
348            utils::sort_indices_data_slices(indices, data, &mut buf);
349        }
350
351        match crate::sparse::utils::check_compressed_structure(
352            inner,
353            outer,
354            indptr.as_ref(),
355            indices.as_ref(),
356        ) {
357            Err(e) => Err((indptr, indices, data, e)),
358            Ok(_) => Ok(Self {
359                storage,
360                nrows,
361                ncols,
362                indptr: crate::IndPtrBase::new_trusted(indptr),
363                indices,
364                data,
365            }),
366        }
367    }
368
369    /// Try create a `CSR` matrix which acts as an owner of its data.
370    ///
371    /// A `CSC` matrix can be created with `new_from_unsorted_csc()`.
372    ///
373    /// If necessary, the indices will be sorted in place.
374    pub fn new_from_unsorted(
375        shape: Shape,
376        indptr: IptrStorage,
377        indices: IStorage,
378        data: DStorage,
379    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
380    where
381        N: Clone,
382    {
383        Self::new_from_unsorted_checked(CSR, shape, indptr, indices, data)
384    }
385
386    /// Try create a `CSC` matrix which acts as an owner of its data.
387    ///
388    /// A `CSR` matrix can be created with `new_from_unsorted_csr()`.
389    ///
390    /// If necessary, the indices will be sorted in place.
391    pub fn new_from_unsorted_csc(
392        shape: Shape,
393        indptr: IptrStorage,
394        indices: IStorage,
395        data: DStorage,
396    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
397    where
398        N: Clone,
399    {
400        Self::new_from_unsorted_checked(CSC, shape, indptr, indices, data)
401    }
402}
403
404/// # Constructor methods for owned sparse matrices
405impl<N, I: SpIndex, Iptr: SpIndex> CsMatI<N, I, Iptr> {
406    /// Identity matrix, stored as a CSR matrix.
407    ///
408    /// ```rust
409    /// use sprs::{CsMat, CsVec};
410    /// let eye = CsMat::eye(5);
411    /// assert!(eye.is_csr());
412    /// let x = CsVec::new(5, vec![0, 2, 4], vec![1., 2., 3.]);
413    /// let y = &eye * &x;
414    /// assert_eq!(x, y);
415    /// ```
416    pub fn eye(dim: usize) -> Self
417    where
418        N: Num + Clone,
419    {
420        let _ = (I::from_usize(dim), Iptr::from_usize(dim)); // Make sure dim fits in type I & Iptr
421        let n = dim;
422        let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
423        let indices = (0..n).map(I::from_usize_unchecked).collect();
424        let data = vec![N::one(); n];
425        Self::new_trusted(CSR, (n, n), indptr, indices, data)
426    }
427
428    /// Identity matrix, stored as a CSC matrix.
429    ///
430    /// ```rust
431    /// use sprs::{CsMat, CsVec};
432    /// let eye = CsMat::eye_csc(5);
433    /// assert!(eye.is_csc());
434    /// let x = CsVec::new(5, vec![0, 2, 4], vec![1., 2., 3.]);
435    /// let y = &eye * &x;
436    /// assert_eq!(x, y);
437    /// ```
438    pub fn eye_csc(dim: usize) -> Self
439    where
440        N: Num + Clone,
441    {
442        let _ = (I::from_usize(dim), Iptr::from_usize(dim)); // Make sure dim fits in type I & Iptr
443        let n = dim;
444        let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
445        let indices = (0..n).map(I::from_usize_unchecked).collect();
446        let data = vec![N::one(); n];
447        Self::new_trusted(CSC, (n, n), indptr, indices, data)
448    }
449    /// Create an empty `CsMat` for building purposes
450    pub fn empty(storage: CompressedStorage, inner_size: usize) -> Self {
451        let shape = match storage {
452            CSR => (0, inner_size),
453            CSC => (inner_size, 0),
454        };
455        Self::new_trusted(
456            storage,
457            shape,
458            vec![Iptr::zero(); 1],
459            Vec::new(),
460            Vec::new(),
461        )
462    }
463
464    /// Create a new `CsMat` representing the zero matrix.
465    /// Hence it has no non-zero elements.
466    pub fn zero(shape: Shape) -> Self {
467        let (nrows, _ncols) = shape;
468        Self::new_trusted(
469            CSR,
470            shape,
471            vec![Iptr::zero(); nrows + 1],
472            Vec::new(),
473            Vec::new(),
474        )
475    }
476
477    /// Reserve the storage for the given additional number of nonzero data
478    pub fn reserve_outer_dim(&mut self, outer_dim_additional: usize) {
479        self.indptr.reserve(outer_dim_additional);
480    }
481
482    /// Reserve the storage for the given additional number of nonzero data
483    pub fn reserve_nnz(&mut self, nnz_additional: usize) {
484        self.indices.reserve(nnz_additional);
485        self.data.reserve(nnz_additional);
486    }
487
488    /// Reserve the storage for the given number of nonzero data
489    pub fn reserve_outer_dim_exact(&mut self, outer_dim_lim: usize) {
490        self.indptr.reserve_exact(outer_dim_lim + 1);
491    }
492
493    /// Reserve the storage for the given number of nonzero data
494    pub fn reserve_nnz_exact(&mut self, nnz_lim: usize) {
495        self.indices.reserve_exact(nnz_lim);
496        self.data.reserve_exact(nnz_lim);
497    }
498
499    /// Create a CSR matrix from a dense matrix, ignoring elements lower than `epsilon`.
500    ///
501    /// If epsilon is negative, it will be clamped to zero.
502    pub fn csr_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
503    where
504        N: Num + Clone + cmp::PartialOrd + Signed,
505    {
506        let epsilon = if epsilon > N::zero() {
507            epsilon
508        } else {
509            N::zero()
510        };
511        let nrows = m.shape()[0];
512        let ncols = m.shape()[1];
513
514        let mut indptr = vec![Iptr::zero(); nrows + 1];
515        let mut nnz = 0;
516        for (row, row_count) in m.outer_iter().zip(&mut indptr[1..]) {
517            nnz += row.iter().filter(|&x| x.abs() > epsilon).count();
518            *row_count = Iptr::from_usize(nnz);
519        }
520
521        let mut indices = Vec::with_capacity(nnz);
522        let mut data = Vec::with_capacity(nnz);
523        for row in m.outer_iter() {
524            for (col_ind, x) in row.iter().enumerate() {
525                if x.abs() > epsilon {
526                    indices.push(I::from_usize(col_ind));
527                    data.push(x.clone());
528                }
529            }
530        }
531        Self {
532            storage: CompressedStorage::CSR,
533            nrows,
534            ncols,
535            indptr: crate::IndPtr::new_trusted(indptr),
536            indices,
537            data,
538        }
539    }
540
541    /// Create a CSC matrix from a dense matrix, ignoring elements lower than `epsilon`.
542    ///
543    /// If epsilon is negative, it will be clamped to zero.
544    pub fn csc_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
545    where
546        N: Num + Clone + cmp::PartialOrd + Signed,
547    {
548        Self::csr_from_dense(m.reversed_axes(), epsilon).transpose_into()
549    }
550
551    /// Append an outer dim to an existing matrix, compressing it in the process
552    pub fn append_outer(self, data: &[N]) -> Self
553    where
554        N: Clone + Zero,
555    {
556        // Safety: enumerate is monotonically increasing
557        unsafe {
558            self.append_outer_iter_unchecked(
559                data.iter()
560                    .cloned()
561                    .enumerate()
562                    .filter(|(_, val)| !val.is_zero()),
563            )
564        }
565    }
566
567    /// Append an outer dim to an existing matrix, increasing the size along the outer
568    /// dimension by one.
569    ///
570    /// # Panics
571    ///
572    /// if the iterator index is **not** monotonically increasing
573    pub fn append_outer_iter<Iter>(self, iter: Iter) -> Self
574    where
575        N: Zero,
576        Iter: IntoIterator<Item = (usize, N)>,
577    {
578        let iter = iter.into_iter();
579        unsafe {
580            self.append_outer_iter_unchecked(AssertOrderedIterator {
581                prev: None,
582                iter: iter.filter(|(_, val)| !val.is_zero()),
583            })
584        }
585    }
586
587    /// Append an outer dim to an existing matrix, increasing the size along the outer
588    /// dimension by one.
589    ///
590    /// # Safety
591    ///
592    /// This is unsafe since indices for each inner dim should be monotonically increasing
593    /// which is not checked. The data values are additionally not checked for zero.
594    /// See `append_outer_iter` for the checked version
595    pub unsafe fn append_outer_iter_unchecked<Iter>(
596        mut self,
597        iter: Iter,
598    ) -> Self
599    where
600        Iter: IntoIterator<Item = (usize, N)>,
601    {
602        let iter = iter.into_iter();
603        if let (_, Some(nnz)) = iter.size_hint() {
604            self.reserve_nnz(nnz)
605        }
606        let mut nnz = self.nnz();
607        for (inner_ind, val) in iter {
608            self.indices.push(I::from_usize(inner_ind));
609            self.data.push(val);
610            nnz += 1;
611        }
612        if let Some(last_inner_ind) = self.indices.last() {
613            assert!(
614                last_inner_ind.index_unchecked() < self.inner_dims(),
615                "inner index out of range"
616            );
617        }
618        match self.storage {
619            CSR => self.nrows += 1,
620            CSC => self.ncols += 1,
621        }
622        self.indptr.push(Iptr::from_usize(nnz));
623        self
624    }
625
626    /// Append an outer dim to an existing matrix, provided by a sparse vector
627    pub fn append_outer_csvec(self, vec: CsVecViewI<N, I>) -> Self
628    where
629        N: Clone,
630    {
631        assert_eq!(self.inner_dims(), vec.dim());
632        // Safety: CsVec has monotonically increasing indices
633        unsafe {
634            self.append_outer_iter_unchecked(
635                vec.iter().map(|(i, val)| (i, val.clone())),
636            )
637        }
638    }
639
640    /// Insert an element in the matrix. If the element is already present,
641    /// its value is overwritten.
642    ///
643    /// Warning: this is not an efficient operation, as it requires
644    /// a non-constant lookup followed by two `Vec` insertions.
645    ///
646    /// The insertion will be efficient, however, if the elements are inserted
647    /// according to the matrix's order, eg following the row order for a CSR
648    /// matrix.
649    pub fn insert(&mut self, row: usize, col: usize, val: N) {
650        match self.storage() {
651            CSR => self.insert_outer_inner(row, col, val),
652            CSC => self.insert_outer_inner(col, row, val),
653        }
654    }
655
656    fn insert_outer_inner(
657        &mut self,
658        outer_ind: usize,
659        inner_ind: usize,
660        val: N,
661    ) {
662        let outer_dims = self.outer_dims();
663        let inner_ind_idx = I::from_usize(inner_ind);
664        if outer_ind >= outer_dims {
665            // we need to add a new outer dimension
666            let last_nnz = self.indptr.nnz_i();
667            self.indptr.resize(outer_ind + 1, last_nnz);
668            self.set_outer_dims(outer_ind + 1);
669            self.indptr.push(last_nnz + Iptr::one());
670            self.indices.push(inner_ind_idx);
671            self.data.push(val);
672        } else {
673            // we need to search for an insertion spot
674            let range = self.indptr.outer_inds_sz(outer_ind);
675            let location =
676                self.indices[range.clone()].binary_search(&inner_ind_idx);
677            match location {
678                Ok(ind) => {
679                    let ind = range.start + ind.index_unchecked();
680                    self.data[ind] = val;
681                    return;
682                }
683                Err(ind) => {
684                    let ind = range.start + ind.index_unchecked();
685                    self.indices.insert(ind, inner_ind_idx);
686                    self.data.insert(ind, val);
687                    self.indptr.record_new_element(outer_ind);
688                }
689            }
690        }
691
692        if inner_ind >= self.inner_dims() {
693            self.set_inner_dims(inner_ind + 1);
694        }
695    }
696
697    fn set_outer_dims(&mut self, outer_dims: usize) {
698        match self.storage() {
699            CSR => self.nrows = outer_dims,
700            CSC => self.ncols = outer_dims,
701        }
702    }
703
704    fn set_inner_dims(&mut self, inner_dims: usize) {
705        match self.storage() {
706            CSR => self.ncols = inner_dims,
707            CSC => self.nrows = inner_dims,
708        }
709    }
710}
711
/// Adapter over an iterator of `(index, value)` pairs which panics as soon
/// as an index is not strictly greater than the one before it.
pub(crate) struct AssertOrderedIterator<Iter> {
    /// Index of the previously yielded pair, if any
    prev: Option<usize>,
    /// Wrapped iterator
    iter: Iter,
}

impl<N, Iter> Iterator for AssertOrderedIterator<Iter>
where
    Iter: Iterator<Item = (usize, N)>,
{
    type Item = (usize, N);

    fn next(&mut self) -> Option<Self::Item> {
        let (index, value) = self.iter.next()?;
        if let Some(last) = self.prev {
            if last >= index {
                panic!("index out of order. {} followed {}", index, last);
            }
        }
        self.prev = Some(index);
        Some((index, value))
    }

    /// Forwarded unchanged from the wrapped iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let hint = self.iter.size_hint();
        hint
    }
}
741
742/// # Constructor methods for sparse matrix views
743///
744/// These constructors can be used to create views over non-matrix data
745/// such as slices.
746impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
747    CsMatViewI<'a, N, I, Iptr>
748{
749    /// Get a view into count contiguous outer dimensions, starting from i.
750    ///
751    /// eg this gets the rows from i to i + count in a CSR matrix
752    ///
753    /// This function is now deprecated, as using an index and a count is not
754    /// ergonomic. The replacement, `slice_outer`, leverages the
755    /// `std::ops::Range` family of types, which is better integrated into the
756    /// ecosystem.
757    #[deprecated(
758        since = "0.10.0",
759        note = "Please use the `slice_outer` method instead"
760    )]
761    pub fn middle_outer_views(
762        &self,
763        i: usize,
764        count: usize,
765    ) -> CsMatViewI<'a, N, I, Iptr> {
766        let iend = i.checked_add(count).unwrap();
767        let (nrows, ncols) = match self.storage {
768            CSR => (count, self.cols()),
769            CSC => (self.rows(), count),
770        };
771        let data_range = self.indptr.outer_inds_slice(i, iend);
772        CsMatViewI {
773            storage: self.storage,
774            nrows,
775            ncols,
776            indptr: self.indptr.middle_slice_rbr(i..iend),
777            indices: &self.indices[data_range.clone()],
778            data: &self.data[data_range],
779        }
780    }
781
782    /// Get an iterator that yields the non-zero locations and values stored in
783    /// this matrix, in the fastest iteration order.
784    ///
785    /// This method will yield the correct lifetime for iterating over a sparse
786    /// matrix view.
787    pub fn iter_rbr(&self) -> CsIter<'a, N, I, Iptr> {
788        CsIter {
789            storage: self.storage,
790            cur_outer: I::zero(),
791            indptr: self.indptr.reborrow(),
792            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
793        }
794    }
795}
796
797/// # Common methods for all variants of compressed sparse matrices.
798impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
799    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
800where
801    I: SpIndex,
802    Iptr: SpIndex,
803    IptrStorage: Deref<Target = [Iptr]>,
804    IndStorage: Deref<Target = [I]>,
805    DataStorage: Deref<Target = [N]>,
806{
807    /// The underlying storage of this matrix
808    pub fn storage(&self) -> CompressedStorage {
809        self.storage
810    }
811
812    /// The number of rows of this matrix
813    pub fn rows(&self) -> usize {
814        self.nrows
815    }
816
817    /// The number of cols of this matrix
818    pub fn cols(&self) -> usize {
819        self.ncols
820    }
821
822    /// The shape of the matrix.
823    /// Equivalent to `let shape = (a.rows(), a.cols())`.
824    pub fn shape(&self) -> Shape {
825        (self.nrows, self.ncols)
826    }
827
828    /// The number of non-zero elements this matrix stores.
829    /// This is often relevant for the complexity of most sparse matrix
830    /// algorithms, which are often linear in the number of non-zeros.
831    pub fn nnz(&self) -> usize {
832        self.indptr.nnz()
833    }
834
835    /// The density of the sparse matrix, defined as the number of non-zero
836    /// elements divided by the maximum number of elements
837    pub fn density(&self) -> f64 {
838        let rows = self.nrows as f64;
839        let cols = self.ncols as f64;
840        let nnz = self.nnz() as f64;
841        nnz / (rows * cols)
842    }
843
844    /// Number of outer dimensions, that ie equal to `self.rows()` for a CSR
845    /// matrix, and equal to `self.cols()` for a CSC matrix
846    pub fn outer_dims(&self) -> usize {
847        outer_dimension(self.storage, self.nrows, self.ncols)
848    }
849
850    /// Number of inner dimensions, that ie equal to `self.cols()` for a CSR
851    /// matrix, and equal to `self.rows()` for a CSC matrix
852    pub fn inner_dims(&self) -> usize {
853        match self.storage {
854            CSC => self.nrows,
855            CSR => self.ncols,
856        }
857    }
858
859    /// Access the element located at row i and column j.
860    /// Will return None if there is no non-zero element at this location.
861    ///
862    /// This access is logarithmic in the number of non-zeros
863    /// in the corresponding outer slice. It is therefore advisable not to rely
864    /// on this for algorithms, and prefer [`outer_iterator`](Self::outer_iterator)
865    /// which accesses elements in storage order.
866    pub fn get(&self, i: usize, j: usize) -> Option<&N> {
867        match self.storage {
868            CSR => self.get_outer_inner(i, j),
869            CSC => self.get_outer_inner(j, i),
870        }
871    }
872
873    /// The array of offsets in the `indices()` `and data()` slices.
874    /// The elements of the slice at outer dimension i
875    /// are available between the elements `indptr\[i\]` and `indptr\[i+1\]`
876    /// in the `indices()` and `data()` slices.
877    ///
878    /// # Example
879    ///
880    /// ```rust
881    /// use sprs::{CsMat};
882    /// let eye : CsMat<f64> = CsMat::eye(5);
883    /// // get the element of row 3
884    /// // there is only one element in this row, with a column index of 3
885    /// // and a value of 1.
886    /// let range = eye.indptr().outer_inds_sz(3);
887    /// assert_eq!(range.start, 3);
888    /// assert_eq!(range.end, 4);
889    /// assert_eq!(eye.indices()[range.start], 3);
890    /// assert_eq!(eye.data()[range.start], 1.);
891    /// ```
892    pub fn indptr(&self) -> crate::IndPtrView<'_, Iptr> {
893        crate::IndPtrView::new_trusted(self.indptr.raw_storage())
894    }
895
896    /// Get an indptr representation suitable for ffi, cloning if necessary to
897    /// get a compatible representation.
898    ///
899    /// # Warning
900    ///
901    /// For ffi usage, one needs to call `Cow::as_ptr`, but it's important
902    /// to keep the `Cow` alive during the lifetime of the pointer. Example
903    /// of a correct and incorrect ffi usage:
904    ///
905    /// ```rust
906    /// let mat: sprs::CsMat<f64> = sprs::CsMat::eye(5);
907    /// let mid = mat.view().middle_outer_views(1, 2);
908    /// let ptr = {
909    ///     let indptr_proper = mid.proper_indptr();
910    ///     println!(
911    ///         "ptr {:?} is valid as long as _indptr_proper_owned is in scope",
912    ///         indptr_proper.as_ptr()
913    ///     );
914    ///     indptr_proper.as_ptr()
915    /// };
916    /// // This line is UB.
917    /// // println!("ptr deref: {}", *ptr);
918    /// ```
919    pub fn proper_indptr(&self) -> std::borrow::Cow<'_, [Iptr]> {
920        self.indptr.to_proper()
921    }
922
923    /// The inner dimension location for each non-zero value. See
924    /// the documentation of `indptr()` for more explanations.
925    pub fn indices(&self) -> &[I] {
926        &self.indices[..]
927    }
928
929    /// The non-zero values. See the documentation of `indptr()`
930    /// for more explanations.
931    pub fn data(&self) -> &[N] {
932        &self.data[..]
933    }
934
935    /// Destruct the matrix object and recycle its storage containers.
936    ///
937    /// # Example
938    ///
939    /// ```rust
940    /// use sprs::{CsMat};
941    /// let (indptr, indices, data) = CsMat::<i32>::eye(3).into_raw_storage();
942    /// assert_eq!(indptr, vec![0, 1, 2, 3]);
943    /// assert_eq!(indices, vec![0, 1, 2]);
944    /// assert_eq!(data, vec![1, 1, 1]);
945    /// ```
946    pub fn into_raw_storage(self) -> (IptrStorage, IndStorage, DataStorage) {
947        let Self {
948            indptr,
949            indices,
950            data,
951            ..
952        } = self;
953        (indptr.into_raw_storage(), indices, data)
954    }
955
956    /// Test whether the matrix is in CSC storage
957    pub fn is_csc(&self) -> bool {
958        self.storage == CSC
959    }
960
961    /// Test whether the matrix is in CSR storage
962    pub fn is_csr(&self) -> bool {
963        self.storage == CSR
964    }
965
966    /// Transpose a matrix in place
967    /// No allocation required (this is simply a storage order change)
968    pub fn transpose_mut(&mut self) {
969        mem::swap(&mut self.nrows, &mut self.ncols);
970        self.storage = self.storage.other_storage();
971    }
972
973    /// Transpose a matrix in place
974    /// No allocation required (this is simply a storage order change)
975    pub fn transpose_into(mut self) -> Self {
976        self.transpose_mut();
977        self
978    }
979
980    /// Transposed view of this matrix
981    /// No allocation required (this is simply a storage order change)
982    pub fn transpose_view(&self) -> CsMatViewI<'_, N, I, Iptr> {
983        CsMatViewI {
984            storage: self.storage.other_storage(),
985            nrows: self.ncols,
986            ncols: self.nrows,
987            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
988            indices: &self.indices[..],
989            data: &self.data[..],
990        }
991    }
992
993    /// Get an owned version of this matrix. If the matrix was already
994    /// owned, this will make a deep copy.
995    pub fn to_owned(&self) -> CsMatI<N, I, Iptr>
996    where
997        N: Clone,
998    {
999        CsMatI {
1000            storage: self.storage,
1001            nrows: self.nrows,
1002            ncols: self.ncols,
1003            indptr: self.indptr.to_owned(),
1004            indices: self.indices.to_vec(),
1005            data: self.data.to_vec(),
1006        }
1007    }
1008
1009    /// Generate a one-hot matrix, compressing the inner dimension.
1010    ///
1011    /// Returns a matrix with the same size, the same CSR/CSC type,
1012    /// and a single value of 1.0 within each populated inner vector.
1013    ///
1014    /// See [`into_csc`](CsMatBase::into_csc) and [`into_csr`](CsMatBase::into_csr)
1015    /// if you need to prepare a matrix
1016    /// for one-hot compression.
1017    pub fn to_inner_onehot(&self) -> CsMatI<N, I, Iptr>
1018    where
1019        N: Clone + Float + PartialOrd,
1020    {
1021        let mut indptr_counter = 0_usize;
1022        let mut indptr: Vec<Iptr> = Vec::with_capacity(self.indptr.len());
1023
1024        let max_data_len = self.indptr.len().min(self.data.len());
1025        let mut indices: Vec<I> = Vec::with_capacity(max_data_len);
1026        let mut data = Vec::with_capacity(max_data_len);
1027
1028        for inner_vec in self.outer_iterator() {
1029            let hot_element = inner_vec
1030                .iter()
1031                .filter(|e| !e.1.is_nan())
1032                .max_by(|a, b| {
1033                    a.1.partial_cmp(b.1)
1034                        .expect("Unexpected NaN value was found")
1035                })
1036                .map(|a| a.0);
1037
1038            indptr.push(Iptr::from_usize(indptr_counter));
1039
1040            if let Some(inner_id) = hot_element {
1041                indices.push(I::from_usize(inner_id));
1042                data.push(N::one());
1043                indptr_counter += 1;
1044            }
1045        }
1046
1047        indptr.push(Iptr::from_usize(indptr_counter));
1048        CsMatI {
1049            storage: self.storage,
1050            nrows: self.rows(),
1051            ncols: self.cols(),
1052            indptr: crate::IndPtr::new_trusted(indptr),
1053            indices,
1054            data,
1055        }
1056    }
1057
1058    /// Clone the matrix with another integer type for indptr and indices
1059    ///
1060    /// # Panics
1061    ///
1062    /// If the indices or indptr values cannot be represented by the requested
1063    /// integer type.
1064    pub fn to_other_types<I2, N2, Iptr2>(&self) -> CsMatI<N2, I2, Iptr2>
1065    where
1066        N: Clone + Into<N2>,
1067        I2: SpIndex,
1068        Iptr2: SpIndex,
1069    {
1070        let indptr = crate::IndPtr::new_trusted(
1071            self.indptr
1072                .raw_storage()
1073                .iter()
1074                .map(|i| Iptr2::from_usize(i.index_unchecked()))
1075                .collect(),
1076        );
1077        let indices = self
1078            .indices
1079            .iter()
1080            .map(|i| I2::from_usize(i.index_unchecked()))
1081            .collect();
1082        let data = self.data.iter().map(|x| x.clone().into()).collect();
1083        CsMatI {
1084            storage: self.storage,
1085            nrows: self.nrows,
1086            ncols: self.ncols,
1087            indptr,
1088            indices,
1089            data,
1090        }
1091    }
1092
1093    /// Return a view into the current matrix
1094    pub fn view(&self) -> CsMatViewI<'_, N, I, Iptr> {
1095        CsMatViewI {
1096            storage: self.storage,
1097            nrows: self.nrows,
1098            ncols: self.ncols,
1099            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1100            indices: &self.indices[..],
1101            data: &self.data[..],
1102        }
1103    }
1104
1105    pub fn structure_view(&self) -> CsStructureViewI<'_, I, Iptr> {
1106        // Safety: std::slice::from_raw_parts requires its passed
1107        // pointer to be valid for the whole length of the slice. We have a
1108        // zero-sized type, so the length is zero, and since we cast
1109        // a non-null pointer, the pointer is valid as all pointers to zero-sized
1110        // types are valid if they are not null.
1111        let zst_data = unsafe {
1112            std::slice::from_raw_parts(
1113                self.data.as_ptr().cast::<()>(),
1114                self.data.len(),
1115            )
1116        };
1117        CsStructureViewI {
1118            storage: self.storage,
1119            nrows: self.nrows,
1120            ncols: self.ncols,
1121            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1122            indices: &self.indices[..],
1123            data: zst_data,
1124        }
1125    }
1126
1127    pub fn to_dense(&self) -> Array<N, Ix2>
1128    where
1129        N: Clone + Zero,
1130    {
1131        let mut res = Array::zeros((self.rows(), self.cols()));
1132        assign_to_dense(res.view_mut(), self.view());
1133        res
1134    }
1135
1136    /// Return an outer iterator for the matrix
1137    ///
1138    /// This can be used for iterating over the rows (resp. cols) of
1139    /// a CSR (resp. CSC) matrix.
1140    ///
1141    /// ```rust
1142    /// use sprs::{CsMat};
1143    /// let eye = CsMat::eye(5);
1144    /// for (row_ind, row_vec) in eye.outer_iterator().enumerate() {
1145    ///     let (col_ind, &val): (_, &f64) = row_vec.iter().next().unwrap();
1146    ///     assert_eq!(row_ind, col_ind);
1147    ///     assert_eq!(val, 1.);
1148    /// }
1149    /// ```
1150    pub fn outer_iterator(
1151        &self,
1152    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewI<'_, N, I>>
1153           + std::iter::ExactSizeIterator<Item = CsVecViewI<'_, N, I>>
1154           + '_ {
1155        self.indptr.iter_outer_sz().map(move |range| {
1156            CsVecViewI::new_trusted(
1157                self.inner_dims(),
1158                // TODO: unsafe slice indexing
1159                &self.indices[range.clone()],
1160                &self.data[range],
1161            )
1162        })
1163    }
1164
1165    /// Return an outer iterator over P*A*P^T, where it is necessary to use
1166    /// `CsVec::iter_perm(perm.inv())` to iterate over the inner dimension.
1167    /// Unstable, this is a convenience function for the crate `sprs-ldl`
1168    /// for now.
1169    #[doc(hidden)]
1170    pub fn outer_iterator_papt<'a, 'perm: 'a>(
1171        &'a self,
1172        perm: PermViewI<'perm, I>,
1173    ) -> impl std::iter::DoubleEndedIterator<Item = (usize, CsVecViewI<'a, N, I>)>
1174           + std::iter::ExactSizeIterator<Item = (usize, CsVecViewI<'a, N, I>)>
1175           + 'a {
1176        (0..self.outer_dims()).map(move |outer_ind| {
1177            let outer_ind_perm = perm.at(outer_ind);
1178            let range = self.indptr.outer_inds_sz(outer_ind_perm);
1179            let indices = &self.indices[range.clone()];
1180            let data = &self.data[range];
1181            // CsMat invariants imply CsVec invariants
1182            let vec = CsVecBase::new_trusted(self.inner_dims(), indices, data);
1183            (outer_ind_perm, vec)
1184        })
1185    }
1186
1187    /// Get the max number of nnz for each outer dim
1188    pub fn max_outer_nnz(&self) -> usize {
1189        self.outer_iterator()
1190            .map(|outer| outer.indices().len())
1191            .max()
1192            .unwrap_or(0)
1193    }
1194
1195    /// Get the degrees of each vertex on a symmetric matrix
1196    ///
1197    /// The nonzero pattern of a symmetric matrix can be interpreted as
1198    /// an undirected graph. In such a graph, a vertex i is connected to another
1199    /// vertex j if there is a corresponding nonzero entry in the matrix at
1200    /// location (i, j).
1201    ///
1202    /// This function returns a vector containing the degree of each vertex,
1203    /// that is to say the number of neighbor of each vertex. We do not
1204    /// count diagonal entries as a neighbor.
1205    pub fn degrees(&self) -> Vec<usize> {
1206        self.outer_iterator()
1207            .enumerate()
1208            .map(|(outer_dim, outer)| {
1209                outer
1210                    .indices()
1211                    .iter()
1212                    .filter(|ind| ind.index() != outer_dim)
1213                    .count()
1214            })
1215            .collect()
1216    }
1217
1218    /// Get a view into the i-th outer dimension (eg i-th row for a CSR matrix)
1219    pub fn outer_view(&self, i: usize) -> Option<CsVecViewI<'_, N, I>> {
1220        if i >= self.outer_dims() {
1221            return None;
1222        }
1223        let range = self.indptr.outer_inds_sz(i);
1224        // CsMat invariants imply CsVec invariants
1225        Some(CsVecViewI::new_trusted(
1226            self.inner_dims(),
1227            // TODO: unsafe slice indexing
1228            &self.indices[range.clone()],
1229            &self.data[range],
1230        ))
1231    }
1232
1233    /// Get the diagonal of a sparse matrix
1234    pub fn diag(&self) -> CsVecI<N, I>
1235    where
1236        N: Clone,
1237    {
1238        let shape = self.shape();
1239        let smallest_dim: usize = cmp::min(shape.0, shape.1);
1240        // Assuming most matrices have dense diagonals, it seems prudent
1241        // to allocate a bit of memory up front
1242        let heuristic = smallest_dim / 2;
1243        let mut index_vec = Vec::with_capacity(heuristic);
1244        let mut data_vec = Vec::with_capacity(heuristic);
1245
1246        for i in 0..smallest_dim {
1247            let optional_index = self.nnz_index(i, i);
1248            if let Some(idx) = optional_index {
1249                data_vec.push(self[idx].clone());
1250                index_vec.push(I::from_usize(i));
1251            }
1252        }
1253        data_vec.shrink_to_fit();
1254        index_vec.shrink_to_fit();
1255        CsVecI::new_trusted(smallest_dim, index_vec, data_vec)
1256    }
1257
1258    /// Iteration over all entries on the diagonal
1259    pub fn diag_iter(
1260        &self,
1261    ) -> impl ExactSizeIterator<Item = Option<&N>>
1262           + DoubleEndedIterator<Item = Option<&N>> {
1263        let smallest_dim = cmp::min(self.ncols, self.nrows);
1264        (0..smallest_dim).map(move |i| self.get_outer_inner(i, i))
1265    }
1266
1267    /// Iteration on outer blocks of size `block_size`
1268    ///
1269    /// # Panics
1270    ///
1271    /// If the block size is 0.
1272    pub fn outer_block_iter(
1273        &self,
1274        block_size: usize,
1275    ) -> impl std::iter::DoubleEndedIterator<Item = CsMatViewI<'_, N, I, Iptr>>
1276           + std::iter::ExactSizeIterator<Item = CsMatViewI<'_, N, I, Iptr>>
1277           + '_ {
1278        (0..self.outer_dims()).step_by(block_size).map(move |i| {
1279            let count = if i + block_size > self.outer_dims() {
1280                self.outer_dims() - i
1281            } else {
1282                block_size
1283            };
1284            self.view().slice_outer_rbr(i..i + count)
1285        })
1286    }
1287
1288    /// Return a new sparse matrix with the same sparsity pattern, with all non-zero values mapped by the function `f`.
1289    pub fn map<F, N2>(&self, f: F) -> CsMatI<N2, I, Iptr>
1290    where
1291        F: FnMut(&N) -> N2,
1292    {
1293        let data: Vec<N2> = self.data.iter().map(f).collect();
1294
1295        CsMatI {
1296            storage: self.storage,
1297            nrows: self.nrows,
1298            ncols: self.ncols,
1299            indptr: self.indptr.to_owned(),
1300            indices: self.indices.to_vec(),
1301            data,
1302        }
1303    }
1304
1305    /// Access an element given its `outer_ind` and `inner_ind`.
1306    /// Will return None if there is no non-zero element at this location.
1307    ///
1308    /// This access is logarithmic in the number of non-zeros
1309    /// in the corresponding outer slice. It is therefore advisable not to rely
1310    /// on this for algorithms, and prefer [`outer_iterator`](Self::outer_iterator)
1311    /// which accesses elements in storage order.
1312    pub fn get_outer_inner(
1313        &self,
1314        outer_ind: usize,
1315        inner_ind: usize,
1316    ) -> Option<&N> {
1317        self.outer_view(outer_ind)
1318            .and_then(|vec| vec.get_rbr(inner_ind))
1319    }
1320
1321    /// Find the non-zero index of the element specified by row and col
1322    ///
1323    /// Searching this index is logarithmic in the number of non-zeros
1324    /// in the corresponding outer slice.
1325    /// Once it is available, the `NnzIndex` enables retrieving the data with
1326    /// O(1) complexity.
1327    pub fn nnz_index(&self, row: usize, col: usize) -> Option<NnzIndex> {
1328        match self.storage() {
1329            CSR => self.nnz_index_outer_inner(row, col),
1330            CSC => self.nnz_index_outer_inner(col, row),
1331        }
1332    }
1333
1334    /// Find the non-zero index of the element specified by `outer_ind` and
1335    /// `inner_ind`.
1336    ///
1337    /// Searching this index is logarithmic in the number of non-zeros
1338    /// in the corresponding outer slice.
1339    pub fn nnz_index_outer_inner(
1340        &self,
1341        outer_ind: usize,
1342        inner_ind: usize,
1343    ) -> Option<NnzIndex> {
1344        if outer_ind >= self.outer_dims() {
1345            return None;
1346        }
1347        let offset = self.indptr.outer_inds_sz(outer_ind).start;
1348        self.outer_view(outer_ind)
1349            .and_then(|vec| vec.nnz_index(inner_ind))
1350            .map(|vec::NnzIndex(ind)| NnzIndex(ind + offset))
1351    }
1352
1353    /// Check the structure of `CsMat` components
1354    /// This will ensure that:
1355    /// * indptr is of length `outer_dim() + 1`
1356    /// * indices and data have the same length, `nnz == indptr[outer_dims()]`
1357    /// * indptr is sorted
1358    /// * indptr values do not exceed [`usize::MAX`](usize::MAX)`/ 2`, as that would mean
1359    ///   indices and indptr would take more space than the addressable memory
1360    /// * indices is sorted for each outer slice
1361    /// * indices are lower than `inner_dims()`
1362    pub fn check_compressed_structure(&self) -> Result<(), StructureError> {
1363        let inner = self.inner_dims();
1364        let outer = self.outer_dims();
1365
1366        if self.indices.len() != self.data.len() {
1367            return Err(StructureError::SizeMismatch(
1368                "Indices and data lengths do not match",
1369            ));
1370        }
1371
1372        utils::check_compressed_structure(
1373            inner,
1374            outer,
1375            self.indptr.raw_storage(),
1376            &self.indices,
1377        )
1378    }
1379
1380    /// Get an iterator that yields the non-zero locations and values stored in
1381    /// this matrix, in the fastest iteration order.
1382    pub fn iter(&self) -> CsIter<'_, N, I, Iptr> {
1383        CsIter {
1384            storage: self.storage,
1385            cur_outer: I::zero(),
1386            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1387            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
1388        }
1389    }
1390}
1391
1392/// # Methods to convert between storage orders
1393impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1394    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1395where
1396    N: Default,
1397    I: SpIndex,
1398    Iptr: SpIndex,
1399    IptrStorage: Deref<Target = [Iptr]>,
1400    IndStorage: Deref<Target = [I]>,
1401    DataStorage: Deref<Target = [N]>,
1402{
1403    /// Create a matrix mathematically equal to this one, but with the
1404    /// opposed storage (a CSC matrix will be converted to CSR, and vice versa)
1405    pub fn to_other_storage(&self) -> CsMatI<N, I, Iptr>
1406    where
1407        N: Clone,
1408    {
1409        let mut indptr = vec![Iptr::zero(); self.inner_dims() + 1];
1410        let mut indices = vec![I::zero(); self.nnz()];
1411        let mut data = vec![N::default(); self.nnz()];
1412        raw::convert_mat_storage(
1413            self.view(),
1414            &mut indptr,
1415            &mut indices,
1416            &mut data,
1417        );
1418        CsMatI {
1419            storage: self.storage().other_storage(),
1420            nrows: self.nrows,
1421            ncols: self.ncols,
1422            indptr: crate::IndPtr::new_trusted(indptr),
1423            indices,
1424            data,
1425        }
1426    }
1427
1428    /// Create a new CSC matrix equivalent to this one.
1429    /// A new matrix will be created even if this matrix was already CSC.
1430    pub fn to_csc(&self) -> CsMatI<N, I, Iptr>
1431    where
1432        N: Clone,
1433    {
1434        match self.storage {
1435            CSR => self.to_other_storage(),
1436            CSC => self.to_owned(),
1437        }
1438    }
1439
1440    /// Create a new CSR matrix equivalent to this one.
1441    /// A new matrix will be created even if this matrix was already CSR.
1442    pub fn to_csr(&self) -> CsMatI<N, I, Iptr>
1443    where
1444        N: Clone,
1445    {
1446        match self.storage {
1447            CSR => self.to_owned(),
1448            CSC => self.to_other_storage(),
1449        }
1450    }
1451}
1452
1453impl<N, I, Iptr> CsMatI<N, I, Iptr>
1454where
1455    N: Default,
1456
1457    I: SpIndex,
1458    Iptr: SpIndex,
1459{
1460    /// Create a new CSC matrix equivalent to this one.
1461    /// If this matrix is CSR, it is converted to CSC
1462    /// If this matrix is CSC, it is returned by value
1463    pub fn into_csc(self) -> Self
1464    where
1465        N: Clone,
1466    {
1467        match self.storage {
1468            CSR => self.to_other_storage(),
1469            CSC => self,
1470        }
1471    }
1472
1473    /// Create a new CSR matrix equivalent to this one.
1474    /// If this matrix is CSC, it is converted to CSR
1475    /// If this matrix is CSR, it is returned by value
1476    pub fn into_csr(self) -> Self
1477    where
1478        N: Clone,
1479    {
1480        match self.storage {
1481            CSR => self,
1482            CSC => self.to_other_storage(),
1483        }
1484    }
1485}
1486
1487/// # Methods for sparse matrices holding mutable access to their values.
1488impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1489    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1490where
1491    I: SpIndex,
1492    Iptr: SpIndex,
1493    IptrStorage: Deref<Target = [Iptr]>,
1494    IndStorage: Deref<Target = [I]>,
1495    DataStorage: DerefMut<Target = [N]>,
1496{
1497    /// Mutable access to the non zero values
1498    ///
1499    /// This enables changing the values without changing the matrix's
1500    /// structure. To also change the matrix's structure,
1501    /// see [modify](fn.modify.html)
1502    pub fn data_mut(&mut self) -> &mut [N] {
1503        &mut self.data[..]
1504    }
1505
1506    /// Sparse matrix self-multiplication by a scalar
1507    pub fn scale(&mut self, val: N)
1508    where
1509        for<'r> N: MulAssign<&'r N>,
1510    {
1511        for data in self.data_mut() {
1512            *data *= &val;
1513        }
1514    }
1515
1516    /// Get a mutable view into the i-th outer dimension
1517    /// (eg i-th row for a CSR matrix)
1518    pub fn outer_view_mut(
1519        &mut self,
1520        i: usize,
1521    ) -> Option<CsVecViewMutI<'_, N, I>> {
1522        if i >= self.outer_dims() {
1523            return None;
1524        }
1525        let range = self.indptr.outer_inds_sz(i);
1526        // CsMat invariants imply CsVec invariants
1527        Some(CsVecBase::new_trusted(
1528            self.inner_dims(),
1529            &self.indices[range.clone()],
1530            &mut self.data[range],
1531        ))
1532    }
1533
1534    /// Get a mutable reference to the element located at row i and column j.
1535    /// Will return None if there is no non-zero element at this location.
1536    ///
1537    /// This access is logarithmic in the number of non-zeros
1538    /// in the corresponding outer slice. It is therefore advisable not to rely
1539    /// on this for algorithms, and prefer [`outer_iterator_mut`](Self::outer_iterator_mut)
1540    /// which accesses elements in storage order.
1541    /// TODO: `outer_iterator_mut` is not yet implemented
1542    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut N> {
1543        match self.storage {
1544            CSR => self.get_outer_inner_mut(i, j),
1545            CSC => self.get_outer_inner_mut(j, i),
1546        }
1547    }
1548
1549    /// Get a mutable reference to an element given its `outer_ind` and `inner_ind`.
1550    /// Will return None if there is no non-zero element at this location.
1551    ///
1552    /// This access is logarithmic in the number of non-zeros
1553    /// in the corresponding outer slice. It is therefore advisable not to rely
1554    /// on this for algorithms, and prefer [`outer_iterator_mut`](Self::outer_iterator_mut)
1555    /// which accesses elements in storage order.
1556    pub fn get_outer_inner_mut(
1557        &mut self,
1558        outer_ind: usize,
1559        inner_ind: usize,
1560    ) -> Option<&mut N> {
1561        if let Some(NnzIndex(index)) =
1562            self.nnz_index_outer_inner(outer_ind, inner_ind)
1563        {
1564            Some(&mut self.data[index])
1565        } else {
1566            None
1567        }
1568    }
1569
1570    /// Set the value of the non-zero element located at (row, col)
1571    ///
1572    /// # Panics
1573    ///
1574    /// - on out-of-bounds access
1575    /// - if no non-zero element exists at the given location
1576    pub fn set(&mut self, row: usize, col: usize, val: N) {
1577        let outer = outer_dimension(self.storage(), row, col);
1578        let inner = inner_dimension(self.storage(), row, col);
1579        let vec::NnzIndex(index) = self
1580            .outer_view(outer)
1581            .and_then(|vec| vec.nnz_index(inner))
1582            .unwrap();
1583        self.data[index] = val;
1584    }
1585
1586    /// Apply a function to every non-zero element
1587    pub fn map_inplace<F>(&mut self, mut f: F)
1588    where
1589        F: FnMut(&N) -> N,
1590    {
1591        for val in &mut self.data[..] {
1592            *val = f(val);
1593        }
1594    }
1595
1596    /// Return a mutable outer iterator for the matrix
1597    ///
1598    /// This iterator yields mutable sparse vector views for each outer
1599    /// dimension. Only the non-zero values can be modified, the
1600    /// structure is kept immutable.
1601    pub fn outer_iterator_mut(
1602        &mut self,
1603    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewMutI<'_, N, I>>
1604           + std::iter::ExactSizeIterator<Item = CsVecViewMutI<'_, N, I>>
1605           + '_ {
1606        let inner_dim = self.inner_dims();
1607        let indices = &self.indices[..];
1608        let data_ptr: *mut N = self.data.as_mut_ptr();
1609        self.indptr.iter_outer_sz().map(move |range| {
1610            // # Safety
1611            // * ranges always point to exclusive parts of data
1612            // * lifetime bound to &mut self
1613            let data: &mut [N] = unsafe {
1614                std::slice::from_raw_parts_mut(
1615                    data_ptr.add(range.start),
1616                    range.end - range.start,
1617                )
1618            };
1619
1620            CsVecViewMutI::new_trusted(inner_dim, &indices[range], data)
1621        })
1622    }
1623
1624    /// Return a mutable view into the current matrix
1625    pub fn view_mut(&mut self) -> CsMatViewMutI<'_, N, I, Iptr> {
1626        CsMatViewMutI {
1627            storage: self.storage,
1628            nrows: self.nrows,
1629            ncols: self.ncols,
1630            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1631            indices: &self.indices[..],
1632            data: &mut self.data[..],
1633        }
1634    }
1635
1636    /// Iteration over all entries on the diagonal
1637    pub fn diag_iter_mut(
1638        &mut self,
1639    ) -> impl ExactSizeIterator<Item = Option<&mut N>>
1640           + DoubleEndedIterator<Item = Option<&mut N>>
1641           + '_ {
1642        let data_ptr: *mut N = self.data[..].as_mut_ptr();
1643        let smallest_dim = cmp::min(self.ncols, self.nrows);
1644        (0..smallest_dim).map(move |i| {
1645            let idx = self.nnz_index_outer_inner(i, i);
1646            if let Some(NnzIndex(idx)) = idx {
1647                // To obtain multiple mutable references to different
1648                // locations in data we must use a pointer and some unsafe.
1649                // # Safety
1650                // This is safe as
1651                // * NnzIndex provides bounds checking
1652                // * diagonal entries are never overlapping in memory
1653                // * no entries are requested more than once
1654                // * nnz_index_outer_inner does not modify or read from entries in self.data
1655                Some(unsafe { &mut *data_ptr.add(idx) })
1656            } else {
1657                None
1658            }
1659        })
1660    }
1661}
1662
1663impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1664    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1665where
1666    I: SpIndex,
1667    Iptr: SpIndex,
1668    IptrStorage: DerefMut<Target = [Iptr]>,
1669    IndStorage: DerefMut<Target = [I]>,
1670    DataStorage: DerefMut<Target = [N]>,
1671{
1672    /// Modify the matrix's structure without changing its nonzero count.
1673    ///
1674    /// The coherence of the structure will be checked afterwards.
1675    ///
1676    /// # Panics
1677    ///
1678    /// If the resulting matrix breaks the `CsMat` invariants
1679    /// (sorted indices, no out of bounds indices).
1680    ///
1681    /// # Example
1682    ///
1683    /// ```rust
1684    /// use sprs::CsMat;
1685    /// // |   1   |
1686    /// // | 1     |
1687    /// // |   1 1 |
1688    /// let mut mat = CsMat::new_csc((3, 3),
1689    ///                                   vec![0, 1, 3, 4],
1690    ///                                   vec![1, 0, 2, 2],
1691    ///                                   vec![1.; 4]);
1692    ///
1693    /// // | 1 2   |
1694    /// // | 1     |
1695    /// // |   1   |
1696    /// mat.modify(|indptr, indices, data| {
1697    ///     indptr[1] = 2;
1698    ///     indptr[2] = 4;
1699    ///     indices[0] = 0;
1700    ///     indices[1] = 1;
1701    ///     indices[2] = 0;
1702    ///     data[2] = 2.;
1703    /// });
1704    /// ```
1705    pub fn modify<F>(&mut self, mut f: F)
1706    where
1707        F: FnMut(&mut [Iptr], &mut [I], &mut [N]),
1708    {
1709        f(
1710            self.indptr.raw_storage_mut(),
1711            &mut self.indices[..],
1712            &mut self.data[..],
1713        );
1714        // This is safe as long as we do the check, if we panic
1715        // the structure can not be retrieved, as &mut self can not pass
1716        // safely across an unwind boundary
1717        self.check_compressed_structure().unwrap();
1718    }
1719}
1720
1721/// Raw functions acting directly on the compressed structure.
1722pub mod raw {
1723    use crate::indexing::SpIndex;
1724    use crate::sparse::prelude::*;
1725    use std::mem::swap;
1726
1727    /*
1728        /// Copy-convert a compressed matrix into the oppposite storage.
1729        ///
1730        /// The input compressed matrix does not need to have its indices sorted,
1731        /// but the output compressed matrix will have its indices sorted.
1732        ///
1733        /// Can be used to implement CSC <-> CSR conversions, or to implement
1734        /// same-storage (copy) transposition.
1735        ///
1736        /// # Panics
1737        ///
1738        /// Panics if indptr contains non-zero values
1739        ///
1740        /// Panics if the output slices don't match the input matrices'
1741        /// corresponding slices.
1742        pub fn convert_storage<N, I>(
1743            in_storage: super::CompressedStorage,
1744            shape: Shape,
1745            in_indtpr: &[I],
1746            in_indices: &[I],
1747            in_data: &[N],
1748            indptr: &mut [I],
1749            indices: &mut [I],
1750            data: &mut [N],
1751        ) where
1752            N: Clone,
1753            I: SpIndex,
1754        {
1755            // we're building a csmat even though the indices are not sorted,
1756            // but it's not a problem since we don't rely on this property.
1757            // FIXME: this would be better with an explicit unsorted matrix type
1758            let mat = CsMatBase {
1759                storage: in_storage,
1760                nrows: shape.0,
1761                ncols: shape.1,
1762                indptr: in_indtpr,
1763                indices: in_indices,
1764                data: in_data,
1765            };
1766
1767            convert_mat_storage(mat, indptr, indices, data);
1768        }
1769    */
1770
1771    /// Copy-convert a csmat into the oppposite storage.
1772    ///
1773    /// Can be used to implement CSC <-> CSR conversions, or to implement
1774    /// same-storage (copy) transposition.
1775    ///
1776    /// # Panics
1777    ///
1778    /// Panics if indptr contains non-zero values
1779    ///
1780    /// Panics if the output slices don't match the input matrices'
1781    /// corresponding slices.
1782    pub fn convert_mat_storage<N: Clone, I: SpIndex, Iptr: SpIndex>(
1783        mat: CsMatViewI<N, I, Iptr>,
1784        indptr: &mut [Iptr],
1785        indices: &mut [I],
1786        data: &mut [N],
1787    ) {
1788        assert_eq!(indptr.len(), mat.inner_dims() + 1);
1789        assert_eq!(indices.len(), mat.indices().len());
1790        assert_eq!(data.len(), mat.data().len());
1791
1792        assert!(indptr.iter().all(num_traits::Zero::is_zero));
1793
1794        assert!(
1795            I::try_from_usize(mat.rows()).is_some(),
1796            "Index type is not large enough to hold the number of rows requested (I::max_value={:?} vs. required {})", I::max_value(), mat.rows(),
1797        );
1798
1799        for vec in mat.outer_iterator() {
1800            for (inner_dim, _) in vec.iter() {
1801                indptr[inner_dim] += Iptr::one();
1802            }
1803        }
1804
1805        let mut cumsum = Iptr::zero();
1806        for iptr in indptr.iter_mut() {
1807            let tmp = *iptr;
1808            *iptr = cumsum;
1809            cumsum += tmp;
1810        }
1811        if let Some(last_iptr) = indptr.last() {
1812            assert_eq!(last_iptr.index(), mat.nnz());
1813        }
1814
1815        for (outer_dim, vec) in mat.outer_iterator().enumerate() {
1816            let outer_dim = I::from_usize_unchecked(outer_dim);
1817            for (inner_dim, val) in vec.iter() {
1818                let dest = indptr[inner_dim].index();
1819                data[dest] = val.clone();
1820                indices[dest] = outer_dim;
1821                indptr[inner_dim] += Iptr::one();
1822            }
1823        }
1824
1825        let mut last = Iptr::zero();
1826        for iptr in indptr.iter_mut() {
1827            swap(iptr, &mut last);
1828        }
1829    }
1830}
1831
1832impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::MulAssign<T>
1833    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1834where
1835    I: SpIndex,
1836    Iptr: SpIndex,
1837    IpStorage: Deref<Target = [Iptr]>,
1838    IStorage: Deref<Target = [I]>,
1839    DStorage: DerefMut<Target = [T]>,
1840    T: std::ops::MulAssign<T> + Clone,
1841{
1842    fn mul_assign(&mut self, rhs: T) {
1843        self.data_mut()
1844            .iter_mut()
1845            .for_each(|v| v.mul_assign(rhs.clone()));
1846    }
1847}
1848
1849impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::DivAssign<T>
1850    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1851where
1852    I: SpIndex,
1853    Iptr: SpIndex,
1854    IpStorage: Deref<Target = [Iptr]>,
1855    IStorage: Deref<Target = [I]>,
1856    DStorage: DerefMut<Target = [T]>,
1857    T: std::ops::DivAssign<T> + Clone,
1858{
1859    fn div_assign(&mut self, rhs: T) {
1860        self.data_mut()
1861            .iter_mut()
1862            .for_each(|v| v.div_assign(rhs.clone()));
1863    }
1864}
1865
1866impl<'a, 'b, N, I, Iptr, IpS1, IS1, DS1, IpS2, IS2, DS2>
1867    Mul<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
1868    for &'a CsMatBase<N, I, IpS1, IS1, DS1, Iptr>
1869where
1870    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Default + Send + Sync,
1871    I: 'a + SpIndex,
1872    Iptr: 'a + SpIndex,
1873    IpS1: 'a + Deref<Target = [Iptr]>,
1874    IS1: 'a + Deref<Target = [I]>,
1875    DS1: 'a + Deref<Target = [N]>,
1876    IpS2: 'b + Deref<Target = [Iptr]>,
1877    IS2: 'b + Deref<Target = [I]>,
1878    DS2: 'b + Deref<Target = [N]>,
1879{
1880    type Output = CsMatI<N, I, Iptr>;
1881
1882    fn mul(
1883        self,
1884        rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
1885    ) -> CsMatI<N, I, Iptr> {
1886        csmat_mul_csmat(self, rhs)
1887    }
1888}
1889
1890/// Multiply two sparse matrices.
1891///
1892/// This function is generic over `MulAcc`, and supports accumulating
1893/// into a different output type. This is not the default for `Mul`,
1894/// as type inference fails for intermediaries
1895pub fn csmat_mul_csmat<
1896    'a,
1897    'b,
1898    N,
1899    A,
1900    B,
1901    I,
1902    Iptr,
1903    IpS1,
1904    IS1,
1905    DS1,
1906    IpS2,
1907    IS2,
1908    DS2,
1909>(
1910    lhs: &'a CsMatBase<A, I, IpS1, IS1, DS1, Iptr>,
1911    rhs: &'b CsMatBase<B, I, IpS2, IS2, DS2, Iptr>,
1912) -> CsMatI<N, I, Iptr>
1913where
1914    N: 'a
1915        + Clone
1916        + crate::MulAcc<A, B>
1917        + crate::MulAcc<B, A>
1918        + num_traits::Zero
1919        + Default
1920        + Send
1921        + Sync,
1922    A: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1923    B: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1924    I: 'a + SpIndex,
1925    Iptr: 'a + SpIndex,
1926    IpS1: 'a + Deref<Target = [Iptr]>,
1927    IS1: 'a + Deref<Target = [I]>,
1928    DS1: 'a + Deref<Target = [A]>,
1929    IpS2: 'b + Deref<Target = [Iptr]>,
1930    IS2: 'b + Deref<Target = [I]>,
1931    DS2: 'b + Deref<Target = [B]>,
1932{
1933    match (lhs.storage(), rhs.storage()) {
1934        (CSR, CSR) => smmp::mul_csr_csr(lhs.view(), rhs.view()),
1935        (CSR, CSC) => {
1936            let rhs_csr = rhs.to_other_storage();
1937            smmp::mul_csr_csr(lhs.view(), rhs_csr.view())
1938        }
1939        (CSC, CSR) => {
1940            let rhs_csc = rhs.to_other_storage();
1941            smmp::mul_csr_csr(rhs_csc.transpose_view(), lhs.transpose_view())
1942                .transpose_into()
1943        }
1944        (CSC, CSC) => {
1945            smmp::mul_csr_csr(rhs.transpose_view(), lhs.transpose_view())
1946                .transpose_into()
1947        }
1948    }
1949}
1950
1951impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Add<&'b ArrayBase<DS2, Ix2>>
1952    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1953where
1954    N: 'a + Copy + Num + Default,
1955    for<'r> &'r N: Mul<Output = N>,
1956    I: 'a + SpIndex,
1957    Iptr: 'a + SpIndex,
1958    IpS: 'a + Deref<Target = [Iptr]>,
1959    IS: 'a + Deref<Target = [I]>,
1960    DS: 'a + Deref<Target = [N]>,
1961    DS2: 'b + ndarray::Data<Elem = N>,
1962{
1963    type Output = Array<N, Ix2>;
1964
1965    fn add(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
1966        let is_standard_layout =
1967            utils::fastest_axis(rhs.view()) == ndarray::Axis(1);
1968        let neuter_element = N::one();
1969        match (self.storage(), is_standard_layout) {
1970            (CSR, true) | (CSC, false) => binop::add_dense_mat_same_ordering(
1971                self,
1972                rhs,
1973                neuter_element,
1974                neuter_element,
1975            ),
1976            (CSR, false) | (CSC, true) => {
1977                let lhs = self.to_other_storage();
1978                binop::add_dense_mat_same_ordering(
1979                    &lhs,
1980                    rhs,
1981                    neuter_element,
1982                    neuter_element,
1983                )
1984            }
1985        }
1986    }
1987}
1988
1989impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix2>>
1990    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1991where
1992    N: 'a + crate::MulAcc + num_traits::Zero + Clone,
1993    I: 'a + SpIndex,
1994    Iptr: 'a + SpIndex,
1995    IpS: 'a + Deref<Target = [Iptr]>,
1996    IS: 'a + Deref<Target = [I]>,
1997    DS: 'a + Deref<Target = [N]>,
1998    DS2: 'b + ndarray::Data<Elem = N>,
1999{
2000    type Output = Array<N, Ix2>;
2001
2002    fn mul(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2003        let rows = self.rows();
2004        let cols = rhs.shape()[1];
2005        // when the number of colums is small, it is more efficient
2006        // to perform the product by iterating over the columns of
2007        // the rhs, otherwise iterating by rows can take advantage of
2008        // vectorized axpy.
2009        match (self.storage(), cols >= 8) {
2010            (CSR, true) => {
2011                let mut res = Array::zeros((rows, cols));
2012                prod::csr_mulacc_dense_rowmaj(
2013                    self.view(),
2014                    rhs.view(),
2015                    res.view_mut(),
2016                );
2017                res
2018            }
2019            (CSR, false) => {
2020                let mut res = Array::zeros((rows, cols).f());
2021                prod::csr_mulacc_dense_colmaj(
2022                    self.view(),
2023                    rhs.view(),
2024                    res.view_mut(),
2025                );
2026                res
2027            }
2028            (CSC, true) => {
2029                let mut res = Array::zeros((rows, cols));
2030                prod::csc_mulacc_dense_rowmaj(
2031                    self.view(),
2032                    rhs.view(),
2033                    res.view_mut(),
2034                );
2035                res
2036            }
2037            (CSC, false) => {
2038                let mut res = Array::zeros((rows, cols).f());
2039                prod::csc_mulacc_dense_colmaj(
2040                    self.view(),
2041                    rhs.view(),
2042                    res.view_mut(),
2043                );
2044                res
2045            }
2046        }
2047    }
2048}
2049
2050impl<N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
2051    for ArrayBase<DS2, Ix2>
2052where
2053    N: Clone + crate::MulAcc + num_traits::Zero + std::fmt::Debug,
2054    I: SpIndex,
2055    IpS: Deref<Target = [I]>,
2056    IS: Deref<Target = [I]>,
2057    DS: Deref<Target = [N]>,
2058    DS2: ndarray::Data<Elem = N>,
2059{
2060    type Output = Array<N, Ix2>;
2061
2062    fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
2063        let rhs_t = rhs.transpose_view();
2064        let lhs_t = self.t();
2065
2066        let rows = rhs_t.rows();
2067        let cols = lhs_t.ncols();
2068        // when the number of colums is small, it is more efficient
2069        // to perform the product by iterating over the columns of
2070        // the rhs, otherwise iterating by rows can take advantage of
2071        // vectorized axpy.
2072        let rres = match (rhs_t.storage(), cols >= 8) {
2073            (CSR, true) => {
2074                let mut res = Array::zeros((rows, cols));
2075                prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2076                res.reversed_axes()
2077            }
2078            (CSR, false) => {
2079                let mut res = Array::zeros((rows, cols).f());
2080                prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2081                res.reversed_axes()
2082            }
2083            (CSC, true) => {
2084                let mut res = Array::zeros((rows, cols));
2085                prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2086                res.reversed_axes()
2087            }
2088            (CSC, false) => {
2089                let mut res = Array::zeros((rows, cols).f());
2090                prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2091                res.reversed_axes()
2092            }
2093        };
2094
2095        assert_eq!(self.shape()[0], rres.shape()[0]);
2096        assert_eq!(rhs.cols(), rres.shape()[1]);
2097        rres
2098    }
2099}
2100
2101impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
2102    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2103where
2104    N: Clone + crate::MulAcc + num_traits::Zero,
2105    I: SpIndex,
2106    Iptr: SpIndex,
2107    IpS: Deref<Target = [Iptr]>,
2108    IS: Deref<Target = [I]>,
2109    DS: Deref<Target = [N]>,
2110    DS2: ndarray::Data<Elem = N>,
2111{
2112    type Output = Array<N, Ix2>;
2113
2114    fn dot(&self, rhs: &ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2115        Mul::mul(self, rhs)
2116    }
2117}
2118
2119impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix1>>
2120    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2121where
2122    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
2123    I: 'a + SpIndex,
2124    Iptr: 'a + SpIndex,
2125    IpS: 'a + Deref<Target = [Iptr]>,
2126    IS: 'a + Deref<Target = [I]>,
2127    DS: 'a + Deref<Target = [N]>,
2128    DS2: 'b + ndarray::Data<Elem = N>,
2129{
2130    type Output = Array<N, Ix1>;
2131
2132    fn mul(self, rhs: &'b ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2133        let rows = self.rows();
2134        let cols = rhs.shape()[0];
2135        #[allow(deprecated)]
2136        let rhs_reshape = rhs.view().into_shape((cols, 1)).unwrap();
2137        let mut res = Array::zeros(rows);
2138        {
2139            #[allow(deprecated)]
2140            let res_reshape = res.view_mut().into_shape((rows, 1)).unwrap();
2141            match self.storage() {
2142                CSR => {
2143                    prod::csr_mulacc_dense_colmaj(
2144                        self.view(),
2145                        rhs_reshape,
2146                        res_reshape,
2147                    );
2148                }
2149                CSC => {
2150                    prod::csc_mulacc_dense_colmaj(
2151                        self.view(),
2152                        rhs_reshape,
2153                        res_reshape,
2154                    );
2155                }
2156            }
2157        }
2158        res
2159    }
2160}
2161
2162impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix1>>
2163    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2164where
2165    N: Clone + crate::MulAcc + num_traits::Zero,
2166    I: SpIndex,
2167    Iptr: SpIndex,
2168    IpS: Deref<Target = [Iptr]>,
2169    IS: Deref<Target = [I]>,
2170    DS: Deref<Target = [N]>,
2171    DS2: ndarray::Data<Elem = N>,
2172{
2173    type Output = Array<N, Ix1>;
2174
2175    fn dot(&self, rhs: &ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2176        Mul::mul(self, rhs)
2177    }
2178}
2179
2180impl<N, I, Iptr, IpS, IS, DS> Index<[usize; 2]>
2181    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2182where
2183    I: SpIndex,
2184    Iptr: SpIndex,
2185    IpS: Deref<Target = [Iptr]>,
2186    IS: Deref<Target = [I]>,
2187    DS: Deref<Target = [N]>,
2188{
2189    type Output = N;
2190
2191    fn index(&self, index: [usize; 2]) -> &N {
2192        let i = index[0];
2193        let j = index[1];
2194        self.get(i, j).unwrap()
2195    }
2196}
2197
2198impl<N, I, Iptr, IpS, IS, DS> IndexMut<[usize; 2]>
2199    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2200where
2201    I: SpIndex,
2202    Iptr: SpIndex,
2203    IpS: Deref<Target = [Iptr]>,
2204    IS: Deref<Target = [I]>,
2205    DS: DerefMut<Target = [N]>,
2206{
2207    fn index_mut(&mut self, index: [usize; 2]) -> &mut N {
2208        let i = index[0];
2209        let j = index[1];
2210        self.get_mut(i, j).unwrap()
2211    }
2212}
2213
2214impl<N, I, Iptr, IpS, IS, DS> Index<NnzIndex>
2215    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2216where
2217    I: SpIndex,
2218    Iptr: SpIndex,
2219    IpS: Deref<Target = [Iptr]>,
2220    IS: Deref<Target = [I]>,
2221    DS: Deref<Target = [N]>,
2222{
2223    type Output = N;
2224
2225    fn index(&self, index: NnzIndex) -> &N {
2226        let NnzIndex(i) = index;
2227        self.data().get(i).unwrap()
2228    }
2229}
2230
2231impl<N, I, Iptr, IpS, IS, DS> IndexMut<NnzIndex>
2232    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2233where
2234    I: SpIndex,
2235    Iptr: SpIndex,
2236    IpS: Deref<Target = [Iptr]>,
2237    IS: Deref<Target = [I]>,
2238    DS: DerefMut<Target = [N]>,
2239{
2240    fn index_mut(&mut self, index: NnzIndex) -> &mut N {
2241        let NnzIndex(i) = index;
2242        self.data_mut().get_mut(i).unwrap()
2243    }
2244}
2245
2246impl<N, I, Iptr, IpS, IS, DS> SparseMat for CsMatBase<N, I, IpS, IS, DS, Iptr>
2247where
2248    I: SpIndex,
2249    Iptr: SpIndex,
2250    IpS: Deref<Target = [Iptr]>,
2251    IS: Deref<Target = [I]>,
2252    DS: Deref<Target = [N]>,
2253{
2254    fn rows(&self) -> usize {
2255        self.rows()
2256    }
2257
2258    fn cols(&self) -> usize {
2259        self.cols()
2260    }
2261
2262    fn nnz(&self) -> usize {
2263        self.nnz()
2264    }
2265}
2266
2267impl<'a, N, I, Iptr, IpS, IS, DS> SparseMat
2268    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2269where
2270    I: 'a + SpIndex,
2271    Iptr: 'a + SpIndex,
2272    N: 'a,
2273    IpS: Deref<Target = [Iptr]>,
2274    IS: Deref<Target = [I]>,
2275    DS: Deref<Target = [N]>,
2276{
2277    fn rows(&self) -> usize {
2278        (*self).rows()
2279    }
2280
2281    fn cols(&self) -> usize {
2282        (*self).cols()
2283    }
2284
2285    fn nnz(&self) -> usize {
2286        (*self).nnz()
2287    }
2288}
2289
2290impl<'a, N, I, IpS, IS, DS, Iptr> IntoIterator
2291    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2292where
2293    I: 'a + SpIndex,
2294    Iptr: 'a + SpIndex,
2295    N: 'a,
2296    IpS: Deref<Target = [Iptr]>,
2297    IS: Deref<Target = [I]>,
2298    DS: Deref<Target = [N]>,
2299{
2300    type Item = (&'a N, (I, I));
2301    type IntoIter = CsIter<'a, N, I, Iptr>;
2302    fn into_iter(self) -> Self::IntoIter {
2303        self.iter()
2304    }
2305}
2306
2307impl<'a, N, I, Iptr> IntoIterator for CsMatViewI<'a, N, I, Iptr>
2308where
2309    I: 'a + SpIndex,
2310    Iptr: 'a + SpIndex,
2311    N: 'a,
2312{
2313    type Item = (&'a N, (I, I));
2314    type IntoIter = CsIter<'a, N, I, Iptr>;
2315    fn into_iter(self) -> Self::IntoIter {
2316        self.iter_rbr()
2317    }
2318}
2319
2320#[cfg(test)]
2321mod test {
2322    use super::CompressedStorage::CSR;
2323    use crate::errors::StructureErrorKind;
2324    use crate::sparse::{CsMat, CsMatI, CsMatView, CsVec};
2325    use crate::test_data::{mat1, mat1_csc, mat1_times_2};
2326    use ndarray::{arr2, Array};
2327
2328    #[test]
2329    fn test_copy() {
2330        let m = mat1();
2331        let view1 = m.view();
2332        let view2 = view1; // this shouldn't move
2333        assert_eq!(view1, view2);
2334    }
2335
2336    #[test]
2337    fn test_new_csr_success() {
2338        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2339        let indices_ok: &[usize] = &[0, 1, 2];
2340        let data_ok: &[f64] = &[1., 1., 1.];
2341        let m = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_ok);
2342        assert!(m.is_ok());
2343    }
2344
2345    #[test]
2346    #[should_panic]
2347    fn test_new_csr_bad_indptr_length() {
2348        let indptr_fail1: &[usize] = &[0, 1, 2];
2349        let indices_ok: &[usize] = &[0, 1, 2];
2350        let data_ok: &[f64] = &[1., 1., 1.];
2351        let res = CsMatView::try_new((3, 3), indptr_fail1, indices_ok, data_ok);
2352        res.unwrap(); // unreachable
2353    }
2354
2355    #[test]
2356    #[should_panic]
2357    fn test_new_csr_out_of_bounds_index() {
2358        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2359        let data_ok: &[f64] = &[1., 1., 1.];
2360        let indices_fail2: &[usize] = &[0, 1, 4];
2361        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail2, data_ok);
2362        res.unwrap(); //unreachable
2363    }
2364
2365    #[test]
2366    #[should_panic]
2367    fn test_new_csr_bad_nnz_count() {
2368        let indices_ok: &[usize] = &[0, 1, 2];
2369        let data_ok: &[f64] = &[1., 1., 1.];
2370        let indptr_fail2: &[usize] = &[0, 1, 2, 4];
2371        let res = CsMatView::try_new((3, 3), indptr_fail2, indices_ok, data_ok);
2372        res.unwrap(); //unreachable
2373    }
2374
2375    #[test]
2376    #[should_panic]
2377    fn test_new_csr_data_indices_mismatch1() {
2378        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2379        let data_ok: &[f64] = &[1., 1., 1.];
2380        let indices_fail1: &[usize] = &[0, 1];
2381        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail1, data_ok);
2382        res.unwrap(); //unreachable
2383    }
2384
2385    #[test]
2386    #[should_panic]
2387    fn test_new_csr_data_indices_mismatch2() {
2388        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2389        let indices_ok: &[usize] = &[0, 1, 2];
2390        let data_fail1: &[f64] = &[1., 1., 1., 1.];
2391        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail1);
2392        res.unwrap(); //unreachable
2393    }
2394
2395    #[test]
2396    #[should_panic]
2397    fn test_new_csr_data_indices_mismatch3() {
2398        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2399        let indices_ok: &[usize] = &[0, 1, 2];
2400        let data_fail2: &[f64] = &[1., 1.];
2401        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail2);
2402        res.unwrap(); //unreachable
2403    }
2404
2405    #[test]
2406    fn test_new_csr_fails() {
2407        let indices_ok: &[usize] = &[0, 1, 2];
2408        let data_ok: &[f64] = &[1., 1., 1.];
2409        let indptr_fail3: &[usize] = &[0, 2, 1, 3];
2410        assert_eq!(
2411            CsMatView::try_new((3, 3), indptr_fail3, indices_ok, data_ok)
2412                .unwrap_err()
2413                .3
2414                .kind(),
2415            StructureErrorKind::Unsorted
2416        );
2417    }
2418
2419    #[test]
2420    fn test_new_csr_fail_indices_ordering() {
2421        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
2422        // good indices would be [2, 3, 3, 4, 2, 1, 3];
2423        let indices: &[usize] = &[3, 2, 3, 4, 2, 1, 3];
2424        let data: &[f64] = &[
2425            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
2426            0.88132896, 0.72527863,
2427        ];
2428        assert_eq!(
2429            CsMatView::try_new((5, 5), indptr, indices, data)
2430                .unwrap_err()
2431                .3
2432                .kind(),
2433            StructureErrorKind::Unsorted
2434        );
2435    }
2436
2437    #[test]
2438    fn test_new_csr_csc_success() {
2439        let indptr_ok: &[usize] = &[0, 2, 5, 6];
2440        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
2441        let data_ok: &[f64] = &[
2442            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
2443            0.46202352,
2444        ];
2445        assert!(
2446            CsMatView::try_new((3, 4), indptr_ok, indices_ok, data_ok).is_ok()
2447        );
2448        assert!(
2449            CsMatView::try_new_csc((4, 3), indptr_ok, indices_ok, data_ok)
2450                .is_ok()
2451        );
2452    }
2453
2454    #[test]
2455    #[should_panic]
2456    fn test_new_csc_bad_indptr_length() {
2457        let indptr_ok: &[usize] = &[0, 2, 5, 6];
2458        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
2459        let data_ok: &[f64] = &[
2460            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
2461            0.46202352,
2462        ];
2463        let res =
2464            CsMatView::try_new_csc((3, 4), indptr_ok, indices_ok, data_ok);
2465        res.unwrap(); //unreachable
2466    }
2467
2468    #[test]
2469    fn test_new_csr_vec_borrowed() {
2470        let indptr_ok = vec![0, 1, 2, 3];
2471        let indices_ok = vec![0, 1, 2];
2472        let data_ok: Vec<f64> = vec![1., 1., 1.];
2473        assert!(
2474            CsMatView::try_new((3, 3), &indptr_ok, &indices_ok, &data_ok)
2475                .is_ok()
2476        );
2477    }
2478
2479    #[test]
2480    fn test_new_csr_vec_owned() {
2481        let indptr_ok = vec![0, 1, 2, 3];
2482        let indices_ok = vec![0, 1, 2];
2483        let data_ok: Vec<f64> = vec![1., 1., 1.];
2484        assert!(CsMat::new_from_unsorted(
2485            (3, 3),
2486            indptr_ok,
2487            indices_ok,
2488            data_ok
2489        )
2490        .is_ok());
2491    }
2492
2493    #[test]
2494    fn test_csr_from_dense() {
2495        let m = Array::eye(3);
2496        let m_sparse = CsMat::csr_from_dense(m.view(), 0.);
2497
2498        assert_eq!(m_sparse, CsMat::eye(3));
2499
2500        let m = arr2(&[
2501            [1., 0., 2., 1e-7, 1.],
2502            [0., 0., 0., 1., 0.],
2503            [3., 0., 1., 0., 0.],
2504        ]);
2505        let m_sparse = CsMat::csr_from_dense(m.view(), 1e-5);
2506
2507        let expected_output = CsMat::new(
2508            (3, 5),
2509            vec![0, 3, 4, 6],
2510            vec![0, 2, 4, 3, 0, 2],
2511            vec![1., 2., 1., 1., 3., 1.],
2512        );
2513
2514        assert_eq!(m_sparse, expected_output);
2515    }
2516
2517    #[test]
2518    fn test_csc_from_dense() {
2519        let m = Array::eye(3);
2520        let m_sparse = CsMat::csc_from_dense(m.view(), 0.);
2521
2522        assert_eq!(m_sparse, CsMat::eye_csc(3));
2523
2524        let m = arr2(&[
2525            [1., 0., 2., 1e-7, 1.],
2526            [0., 0., 0., 1., 0.],
2527            [3., 0., 1., 0., 0.],
2528        ]);
2529        let m_sparse = CsMat::csc_from_dense(m.view(), 1e-5);
2530
2531        let expected_output = CsMat::new_csc(
2532            (3, 5),
2533            vec![0, 2, 2, 4, 5, 6],
2534            vec![0, 2, 0, 2, 1, 0],
2535            vec![1., 3., 2., 1., 1., 1.],
2536        );
2537
2538        assert_eq!(m_sparse, expected_output);
2539    }
2540
2541    #[test]
2542    fn owned_csr_unsorted_indices() {
2543        let indptr = vec![0, 3, 3, 5, 6, 7];
2544        let indices_sorted = &[1, 2, 3, 2, 3, 4, 4];
2545        let indices_shuffled = vec![1, 3, 2, 2, 3, 4, 4];
2546        let mut data: Vec<i32> = (0..7).collect();
2547        let m = CsMat::new_from_unsorted(
2548            (5, 5),
2549            indptr,
2550            indices_shuffled,
2551            data.clone(),
2552        )
2553        .unwrap();
2554        assert_eq!(m.indices(), indices_sorted);
2555        data.swap(1, 2);
2556        assert_eq!(m.data(), &data[..]);
2557    }
2558
2559    #[test]
2560    fn new_csr_with_empty_row() {
2561        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
2562        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
2563        let data: &[f64] = &[
2564            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
2565            0.39244208, 0.57202407,
2566        ];
2567        assert!(CsMatView::try_new((5, 5), indptr, indices, data).is_ok());
2568    }
2569
2570    #[test]
2571    fn csr_to_csc() {
2572        let a = mat1();
2573        let a_csc_ground_truth = mat1_csc();
2574        let a_csc = a.to_other_storage();
2575        assert_eq!(a_csc, a_csc_ground_truth);
2576    }
2577
2578    #[test]
2579    fn test_self_smul() {
2580        let mut a = mat1();
2581        a.scale(2.);
2582        let c_true = mat1_times_2();
2583        assert_eq!(a.indptr(), c_true.indptr());
2584        assert_eq!(a.indices(), c_true.indices());
2585        assert_eq!(a.data(), c_true.data());
2586    }
2587
2588    #[test]
2589    fn outer_block_iter() {
2590        let mat: CsMat<f64> = CsMat::eye(11);
2591        let mut block_iter = mat.outer_block_iter(3);
2592        assert_eq!(block_iter.next().unwrap().rows(), 3);
2593        assert_eq!(block_iter.next().unwrap().rows(), 3);
2594        assert_eq!(block_iter.next().unwrap().rows(), 3);
2595        assert_eq!(block_iter.next().unwrap().rows(), 2);
2596        assert_eq!(block_iter.next(), None);
2597
2598        let mut block_iter = mat.outer_block_iter(4);
2599        assert_eq!(block_iter.next().unwrap().cols(), 11);
2600        block_iter.next().unwrap();
2601        block_iter.next().unwrap();
2602        assert_eq!(block_iter.next(), None);
2603    }
2604
2605    #[test]
2606    fn middle_outer_views() {
2607        let size = 11;
2608        let csr: CsMat<f64> = CsMat::eye(size);
2609        #[allow(deprecated)]
2610        let v = csr.view().middle_outer_views(1, 3);
2611        assert_eq!(v.shape(), (3, size));
2612        assert_eq!(v.nnz(), 3);
2613
2614        let csc = csr.to_other_storage();
2615        #[allow(deprecated)]
2616        let v = csc.view().middle_outer_views(1, 3);
2617        assert_eq!(v.shape(), (size, 3));
2618        assert_eq!(v.nnz(), 3);
2619    }
2620
2621    #[test]
2622    fn nnz_index() {
2623        let mat: CsMat<f64> = CsMat::eye(11);
2624
2625        assert_eq!(mat.nnz_index(2, 3), None);
2626        assert_eq!(mat.nnz_index(5, 7), None);
2627        assert_eq!(mat.nnz_index(0, 11), None);
2628        assert_eq!(mat.nnz_index(0, 0), Some(super::NnzIndex(0)));
2629        assert_eq!(mat.nnz_index(7, 7), Some(super::NnzIndex(7)));
2630        assert_eq!(mat.nnz_index(10, 10), Some(super::NnzIndex(10)));
2631
2632        let index = mat.nnz_index(8, 8).unwrap();
2633        assert_eq!(mat[index], 1.);
2634        let mut mat = mat;
2635        mat[index] = 2.;
2636        assert_eq!(mat[index], 2.);
2637    }
2638
2639    #[test]
2640    fn index() {
2641        // | 0 2 0 |
2642        // | 1 0 0 |
2643        // | 0 3 4 |
2644        let mat = CsMat::new_csc(
2645            (3, 3),
2646            vec![0, 1, 3, 4],
2647            vec![1, 0, 2, 2],
2648            vec![1., 2., 3., 4.],
2649        );
2650        assert_eq!(mat[[1, 0]], 1.);
2651        assert_eq!(mat[[0, 1]], 2.);
2652        assert_eq!(mat[[2, 1]], 3.);
2653        assert_eq!(mat[[2, 2]], 4.);
2654        assert_eq!(mat.get(0, 0), None);
2655        assert_eq!(mat.get(4, 4), None);
2656    }
2657
2658    #[test]
2659    fn get_mut() {
2660        // | 0 1 0 |
2661        // | 1 0 0 |
2662        // | 0 1 1 |
2663        let mut mat = CsMat::new_csc(
2664            (3, 3),
2665            vec![0, 1, 3, 4],
2666            vec![1, 0, 2, 2],
2667            vec![1.; 4],
2668        );
2669
2670        *mat.get_mut(2, 1).unwrap() = 3.;
2671
2672        let exp = CsMat::new_csc(
2673            (3, 3),
2674            vec![0, 1, 3, 4],
2675            vec![1, 0, 2, 2],
2676            vec![1., 1., 3., 1.],
2677        );
2678
2679        assert_eq!(mat, exp);
2680
2681        mat[[2, 2]] = 5.;
2682        let exp = CsMat::new_csc(
2683            (3, 3),
2684            vec![0, 1, 3, 4],
2685            vec![1, 0, 2, 2],
2686            vec![1., 1., 3., 5.],
2687        );
2688
2689        assert_eq!(mat, exp);
2690    }
2691
2692    #[test]
2693    fn map() {
2694        // | 0 1 0 |
2695        // | 1 0 0 |
2696        // | 0 1 1 |
2697        let mat = CsMat::new_csc(
2698            (3, 3),
2699            vec![0, 1, 3, 4],
2700            vec![1, 0, 2, 2],
2701            vec![1.; 4],
2702        );
2703
2704        let mut res = mat.map(|&x| x + 2.);
2705        let expected = CsMat::new_csc(
2706            (3, 3),
2707            vec![0, 1, 3, 4],
2708            vec![1, 0, 2, 2],
2709            vec![3.; 4],
2710        );
2711        assert_eq!(res, expected);
2712
2713        res.map_inplace(|&x| x / 3.);
2714        assert_eq!(res, mat);
2715    }
2716
2717    #[test]
2718    fn insert() {
2719        // | 0 1 0 |
2720        // | 1 0 0 |
2721        // | 0 1 1 |
2722        let mut mat = CsMat::empty(CSR, 0);
2723        mat.reserve_outer_dim(3);
2724        mat.reserve_nnz(4);
2725        // exercise the fast and easy path where the elements are added
2726        // in row order for a CSR matrix
2727        mat.insert(0, 1, 1.);
2728        mat.insert(1, 0, 1.);
2729        mat.insert(2, 1, 1.);
2730        mat.insert(2, 2, 1.);
2731
2732        let expected =
2733            CsMat::new((3, 3), vec![0, 1, 2, 4], vec![1, 0, 1, 2], vec![1.; 4]);
2734        assert_eq!(mat, expected);
2735
2736        // | 2 1 0 |
2737        // | 1 0 0 |
2738        // | 0 1 1 |
2739        // exercise adding inside an already formed row (ie a search needs
2740        // to be performed)
2741        mat.insert(0, 0, 2.);
2742        let expected = CsMat::new(
2743            (3, 3),
2744            vec![0, 2, 3, 5],
2745            vec![0, 1, 0, 1, 2],
2746            vec![2., 1., 1., 1., 1.],
2747        );
2748        assert_eq!(mat, expected);
2749
2750        // | 2 1 0 |
2751        // | 3 0 0 |
2752        // | 0 1 1 |
2753        // exercise the fact that inserting in an existing element
2754        // should change this element's value
2755        mat.insert(1, 0, 3.);
2756        let expected = CsMat::new(
2757            (3, 3),
2758            vec![0, 2, 3, 5],
2759            vec![0, 1, 0, 1, 2],
2760            vec![2., 1., 3., 1., 1.],
2761        );
2762        assert_eq!(mat, expected);
2763    }
2764
2765    #[test]
2766    /// Non-regression test for https://github.com/vbarrielle/sprs/issues/129
2767    fn bug_129() {
2768        let mut mat = CsMat::zero((3, 100));
2769        mat.insert(2, 3, 42);
2770        let mut iter = mat.iter();
2771        assert_eq!(iter.next(), Some((&42, (2, 3))));
2772        assert_eq!(iter.next(), None);
2773    }
2774
2775    #[test]
2776    fn iter_mut() {
2777        // | 0 1 0 |
2778        // | 1 0 0 |
2779        // | 0 1 1 |
2780        let mut mat = CsMat::new_csc(
2781            (3, 3),
2782            vec![0, 1, 3, 4],
2783            vec![1, 0, 2, 2],
2784            vec![1.; 4],
2785        );
2786
2787        for mut col_vec in mat.outer_iterator_mut() {
2788            for (row_ind, val) in col_vec.iter_mut() {
2789                *val = row_ind as f64 + 1.;
2790            }
2791        }
2792
2793        let expected = CsMat::new_csc(
2794            (3, 3),
2795            vec![0, 1, 3, 4],
2796            vec![1, 0, 2, 2],
2797            vec![2., 1., 3., 3.],
2798        );
2799        assert_eq!(mat, expected);
2800    }
2801
2802    #[test]
2803    #[should_panic]
2804    fn modify_fail() {
2805        let mut mat = CsMat::new_csc(
2806            (3, 3),
2807            vec![0, 1, 3, 4],
2808            vec![1, 0, 2, 2],
2809            vec![1.; 4],
2810        );
2811
2812        // we panic because we forget to modify the last index, which gets
2813        // pushed in the same col as its predecessor, yet has the same value
2814        mat.modify(|indptr, indices, data| {
2815            indptr[1] = 2;
2816            indptr[2] = 4;
2817            indices[0] = 0;
2818            indices[1] = 1;
2819            data[2] = 2.;
2820        });
2821    }
2822
2823    #[test]
2824    fn convert_types() {
2825        let mat: CsMat<f32> = CsMat::eye(3);
2826        let mat_: CsMatI<f64, u32> = mat.to_other_types();
2827        assert_eq!(mat_.indptr(), &[0, 1, 2, 3][..]);
2828
2829        let mat = CsMatI::new_csc(
2830            (3, 3),
2831            vec![0u32, 1, 3, 4],
2832            vec![1, 0, 2, 2],
2833            vec![1.; 4],
2834        );
2835        let mat_: CsMatI<f32, usize, u32> = mat.to_other_types();
2836        assert_eq!(mat_.indptr(), &[0, 1, 3, 4][..]);
2837        assert_eq!(mat_.data(), &[1.0f32, 1., 1., 1.]);
2838    }
2839
2840    #[test]
2841    fn iter() {
2842        let mat = CsMat::new_csc(
2843            (3, 3),
2844            vec![0, 1, 3, 4],
2845            vec![1, 0, 2, 2],
2846            vec![1.; 4],
2847        );
2848        let mut iter = mat.iter();
2849        assert_eq!(iter.next(), Some((&1., (1, 0))));
2850        assert_eq!(iter.next(), Some((&1., (0, 1))));
2851        assert_eq!(iter.next(), Some((&1., (2, 1))));
2852        assert_eq!(iter.next(), Some((&1., (2, 2))));
2853        assert_eq!(iter.next(), None);
2854    }
2855
2856    #[test]
2857    fn degrees() {
2858        // | 1 0 0 3 1 |
2859        // | 0 2 0 0 0 |
2860        // | 0 0 0 1 0 |
2861        // | 3 0 1 1 0 |
2862        // | 1 0 0 0 1 |
2863        let mat = CsMat::new_csc(
2864            (5, 5),
2865            vec![0, 3, 4, 5, 8, 10],
2866            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
2867            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
2868        );
2869
2870        let degrees = mat.degrees();
2871        assert_eq!(&degrees, &[2, 0, 1, 2, 1],);
2872    }
2873
2874    #[test]
2875    fn diag() {
2876        // | 1 0 0 3 1 |
2877        // | 0 2 0 0 0 |
2878        // | 0 0 0 1 0 |
2879        // | 3 0 1 1 0 |
2880        // | 1 0 0 0 1 |
2881        let mat = CsMat::new_csc(
2882            (5, 5),
2883            vec![0, 3, 4, 5, 8, 10],
2884            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
2885            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
2886        );
2887
2888        let diag = mat.diag();
2889        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
2890        assert_eq!(diag, expected);
2891
2892        let mut iter = mat.diag_iter();
2893        assert_eq!(iter.next().unwrap(), Some(&1));
2894        assert_eq!(iter.next().unwrap(), Some(&2));
2895        assert_eq!(iter.next().unwrap(), None);
2896        assert_eq!(iter.next().unwrap(), Some(&1));
2897        assert_eq!(iter.next().unwrap(), Some(&1));
2898        assert_eq!(iter.next(), None);
2899    }
2900
2901    #[test]
2902    #[cfg_attr(miri, ignore)]
2903    fn diag_mut() {
2904        // | 1 0 0 3 1 |
2905        // | 0 2 0 0 0 |
2906        // | 0 0 0 1 0 |
2907        // | 3 0 1 1 0 |
2908        // | 1 0 0 0 1 |
2909        let mut mat = CsMat::new_csc(
2910            (5, 5),
2911            vec![0, 3, 4, 5, 8, 10],
2912            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
2913            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
2914        );
2915
2916        let mut diags = mat.diag_iter_mut().collect::<Vec<_>>();
2917        diags[4].as_mut().map(|x| **x *= 3);
2918        diags[3].as_mut().map(|x| **x -= 4);
2919        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, -3, 3]);
2920        assert_eq!(mat.diag(), expected);
2921    }
2922
2923    #[test]
2924    fn diag_rectangular() {
2925        // | 1 0 0 3 1 3|
2926        // | 0 2 0 0 0 0|
2927        // | 0 0 0 1 0 1|
2928        // | 3 0 1 1 0 0|
2929        // | 1 0 0 0 1 0|
2930        let mat = CsMat::new_csc(
2931            (5, 6),
2932            vec![0, 3, 4, 5, 8, 10, 12],
2933            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4, 0, 2],
2934            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1, 3, 1],
2935        );
2936
2937        let diag = mat.diag();
2938        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
2939        assert_eq!(diag, expected);
2940
2941        let mut iter = mat.diag_iter();
2942        assert_eq!(iter.next().unwrap(), Some(&1));
2943        assert_eq!(iter.next().unwrap(), Some(&2));
2944        assert_eq!(iter.next().unwrap(), None);
2945        assert_eq!(iter.next().unwrap(), Some(&1));
2946        assert_eq!(iter.next().unwrap(), Some(&1));
2947        assert_eq!(iter.next(), None);
2948    }
2949
2950    #[test]
2951    fn onehot_zero() {
2952        let onehot: CsMat<f32> = CsMat::zero((3, 3)).to_inner_onehot();
2953
2954        assert!(onehot.is_csr());
2955        assert_eq!(CsMat::zero((3, 3)), onehot);
2956    }
2957
2958    #[test]
2959    fn onehot_eye() {
2960        let mat = CsMat::new(
2961            (2, 2),
2962            vec![0, 2, 4],
2963            vec![0, 1, 0, 1],
2964            vec![2.0, 0.0, 0.0, 2.0],
2965        );
2966
2967        let onehot = mat.to_inner_onehot();
2968
2969        assert!(onehot.is_csr());
2970        assert_eq!(CsMat::eye(2), onehot);
2971    }
2972
2973    #[test]
2974    fn onehot_sparse_csc() {
2975        let mat = CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![2.0]);
2976
2977        let onehot = mat.to_inner_onehot();
2978
2979        let expected =
2980            CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![1.0]);
2981
2982        assert!(onehot.is_csc());
2983        assert_eq!(expected, onehot);
2984    }
2985
2986    #[test]
2987    fn onehot_ignores_nan() {
2988        let mat = CsMat::new(
2989            (2, 2),
2990            vec![0, 2, 3],
2991            vec![0, 1, 1],
2992            vec![2.0, std::f64::NAN, 2.0],
2993        );
2994
2995        let onehot = mat.to_inner_onehot();
2996
2997        assert!(onehot.is_csr());
2998        assert_eq!(CsMat::eye(2), onehot);
2999    }
3000
3001    #[test]
3002    fn mul_assign() {
3003        let mut m1 = crate::TriMat::new((6, 9));
3004        m1.add_triplet(1, 1, 8_i32);
3005        m1.add_triplet(1, 2, 7);
3006        m1.add_triplet(0, 1, 6);
3007        m1.add_triplet(0, 8, 5);
3008        m1.add_triplet(4, 2, 4);
3009        let mut m1: CsMat<_> = m1.to_csr();
3010
3011        m1 *= 2;
3012        for (&v, (j, i)) in m1.iter() {
3013            match (j, i) {
3014                (1, 1) => assert_eq!(v, 16),
3015                (1, 2) => assert_eq!(v, 14),
3016                (0, 1) => assert_eq!(v, 12),
3017                (0, 8) => assert_eq!(v, 10),
3018                (4, 2) => assert_eq!(v, 8),
3019                _ => panic!(),
3020            }
3021        }
3022    }
3023
3024    #[test]
3025    fn div_assign() {
3026        let mut m1 = crate::TriMat::new((6, 9));
3027        m1.add_triplet(1, 1, 8_i32);
3028        m1.add_triplet(1, 2, 7);
3029        m1.add_triplet(0, 1, 6);
3030        m1.add_triplet(0, 8, 5);
3031        m1.add_triplet(4, 2, 4);
3032        let mut m1: CsMat<_> = m1.to_csr();
3033
3034        m1 /= 2;
3035        for (&v, (j, i)) in m1.iter() {
3036            match (j, i) {
3037                (1, 1) => assert_eq!(v, 4),
3038                (1, 2) => assert_eq!(v, 3),
3039                (0, 1) => assert_eq!(v, 3),
3040                (0, 8) => assert_eq!(v, 2),
3041                (4, 2) => assert_eq!(v, 2),
3042                _ => panic!(),
3043            }
3044        }
3045    }
3046
3047    #[test]
3048    fn issue_99() {
3049        let a = crate::TriMat::<i32>::new((10, 1)).to_csc::<usize>();
3050        let b = crate::TriMat::<i32>::new((1, 9)).to_csr();
3051        let _c = &a * &b;
3052    }
3053}
3054
#[cfg(feature = "approx")]
mod approx_impls {
    use super::*;
    use approx::*;

    /// Compare two matrices entry by entry, regardless of their storage
    /// orders.
    ///
    /// First, every stored entry of `lhs` is checked against the value at
    /// the same `(row, col)` in `rhs`. Then every stored entry of `rhs` is
    /// checked against `lhs`. A position with no stored value in the
    /// looked-up matrix compares against `N::zero()`. `eq` is always called
    /// as `eq(stored_entry, looked_up_entry)`.
    ///
    /// Both passes are required, because the first pass alone would miss
    /// entries stored only in `rhs`. Callers check that the shapes agree
    /// before calling this.
    fn entries_match<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2, F>(
        lhs: &CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>,
        rhs: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
        eq: F,
    ) -> bool
    where
        I: SpIndex,
        Iptr: SpIndex,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: num_traits::Zero,
        F: Fn(&N, &N) -> bool,
    {
        // Built once and shared by both passes, instead of once per entry.
        let zero = N::zero();
        let lhs_matches_rhs = lhs.iter().all(|(n, (i, j))| {
            let looked_up =
                rhs.get(i.to_usize().unwrap(), j.to_usize().unwrap());
            eq(n, looked_up.unwrap_or(&zero))
        });
        // `&&` keeps the original early exit: the second pass only runs if
        // the first one succeeded.
        lhs_matches_rhs
            && rhs.iter().all(|(n, (i, j))| {
                let looked_up =
                    lhs.get(i.to_usize().unwrap(), j.to_usize().unwrap());
                eq(n, looked_up.unwrap_or(&zero))
            })
    }

    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        AbsDiffEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: AbsDiffEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        type Epsilon = N::Epsilon;
        fn default_epsilon() -> N::Epsilon {
            N::default_epsilon()
        }
        /// Matrices of different shapes are never equal. With the same
        /// storage, corresponding outer vectors are compared pairwise.
        /// Otherwise, entries are looked up across storages, treating
        /// missing entries as zero.
        fn abs_diff_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                // Equal shapes and storage imply the same number of outer
                // vectors, so zipping them pairs every vector.
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.abs_diff_eq(&r2, epsilon.clone()))
            } else {
                entries_match(self, other, |a, b| {
                    a.abs_diff_eq(b, epsilon.clone())
                })
            }
        }
    }

    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        UlpsEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: UlpsEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_ulps() -> u32 {
            N::default_max_ulps()
        }
        /// Same strategy as `abs_diff_eq`, using `N::ulps_eq` per entry.
        fn ulps_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_ulps: u32,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.ulps_eq(&r2, epsilon.clone(), max_ulps))
            } else {
                entries_match(self, other, |a, b| {
                    a.ulps_eq(b, epsilon.clone(), max_ulps)
                })
            }
        }
    }

    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        RelativeEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: RelativeEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_relative() -> N::Epsilon {
            N::default_max_relative()
        }
        /// Same strategy as `abs_diff_eq`, using `N::relative_eq` per entry.
        fn relative_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_relative: N::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator().zip(other.outer_iterator()).all(
                    |(r1, r2)| {
                        r1.relative_eq(
                            &r2,
                            epsilon.clone(),
                            max_relative.clone(),
                        )
                    },
                )
            } else {
                entries_match(self, other, |a, b| {
                    a.relative_eq(b, epsilon.clone(), max_relative.clone())
                })
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use crate::*;

        #[test]
        fn different_shapes() {
            let mut m1 = TriMat::new((3, 2));
            m1.add_triplet(1, 1, 8_u8);
            let m1: CsMat<_> = m1.to_csr();
            let mut m2 = TriMat::new((2, 3));
            m2.add_triplet(1, 1, 8_u8);
            let m2 = m2.to_csr();

            ::approx::assert_abs_diff_ne!(m1, m2);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc());
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8_u8);
            m1.add_triplet(1, 2, 7_u8);
            m1.add_triplet(0, 1, 6_u8);
            m1.add_triplet(0, 8, 5_u8);
            m1.add_triplet(4, 2, 4_u8);

            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();

            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0
            );

            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);

            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();

            ::approx::assert_abs_diff_eq!(m1, m2);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc());
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2.to_csc());

            ::approx::assert_relative_eq!(m1, m2);
            ::approx::assert_relative_eq!(m1.to_csc(), m2);
            ::approx::assert_relative_eq!(m1, m2.to_csc());
            ::approx::assert_relative_eq!(m1.to_csc(), m2.to_csc());

            ::approx::assert_ulps_eq!(m1, m2);
            ::approx::assert_ulps_eq!(m1.to_csc(), m2);
            ::approx::assert_ulps_eq!(m1, m2.to_csc());
            ::approx::assert_ulps_eq!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn almost_equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);
            let m1: CsMat<_> = m1.to_csr();

            let mut m2 = TriMat::new((6, 9));
            m2.add_triplet(1, 1, 8.0_f32);
            m2.add_triplet(1, 2, 7.0 - 0.5); // 0.5 subtracted
            m2.add_triplet(0, 1, 6.0);
            m2.add_triplet(0, 8, 5.0);
            m2.add_triplet(4, 2, 4.0);
            m2.add_triplet(4, 3, 0.2); // extra element
            let m2 = m2.to_csr();

            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.6
            );

            ::approx::assert_abs_diff_ne!(m1, m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc(), epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.4
            );
        }

        /// An entry stored only in the right-hand operand must be detected
        /// when the storages differ. This exercises the second
        /// (rhs-to-lhs) pass of the cross-storage comparison.
        #[test]
        fn entry_only_in_one_operand() {
            let mut m1 = TriMat::new((3, 3));
            m1.add_triplet(0, 0, 1.0_f64);
            let m1: CsMat<_> = m1.to_csr();

            let mut m2 = TriMat::new((3, 3));
            m2.add_triplet(0, 0, 1.0_f64);
            m2.add_triplet(2, 1, 1.0);
            let m2: CsMat<_> = m2.to_csr();

            ::approx::assert_abs_diff_ne!(m1, m2.to_csc(), epsilon = 0.5);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2, epsilon = 0.5);
            ::approx::assert_abs_diff_ne!(m2, m1.to_csc(), epsilon = 0.5);
            ::approx::assert_relative_ne!(m1, m2.to_csc());
            ::approx::assert_ulps_ne!(m1.to_csc(), m2);
        }
    }
}