sprs_rssn/sparse/
triplet_iter.rs

1//! A structure for iterating over the non-zero values of any kind of
2//! sparse matrix.
3
4use std::ops::Add;
5
6use crate::indexing::SpIndex;
7use crate::sparse::{CsMatI, TriMatIter};
8use crate::CompressedStorage;
9
10impl<'a, N, I, RI, CI, DI> Iterator for TriMatIter<RI, CI, DI>
11where
12    I: 'a + SpIndex,
13    N: 'a,
14    RI: Iterator<Item = &'a I>,
15    CI: Iterator<Item = &'a I>,
16    DI: Iterator<Item = &'a N>,
17{
18    type Item = (&'a N, (I, I));
19
20    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
21        match (self.row_inds.next(), self.col_inds.next(), self.data.next()) {
22            (Some(row), Some(col), Some(val)) => Some((val, (*row, *col))),
23            _ => None,
24        }
25    }
26
27    fn size_hint(&self) -> (usize, Option<usize>) {
28        self.row_inds.size_hint() // FIXME merge hints?
29    }
30}
31
32impl<'a, N, I, RI, CI, DI> TriMatIter<RI, CI, DI>
33where
34    I: 'a + SpIndex,
35    N: 'a,
36    RI: Iterator<Item = &'a I>,
37    CI: Iterator<Item = &'a I>,
38    DI: Iterator<Item = &'a N>,
39{
40    /// Create a new `TriMatIter` from iterators
41    pub fn new(
42        shape: (usize, usize),
43        nnz: usize,
44        row_inds: RI,
45        col_inds: CI,
46        data: DI,
47    ) -> Self {
48        Self {
49            rows: shape.0,
50            cols: shape.1,
51            nnz,
52            row_inds,
53            col_inds,
54            data,
55        }
56    }
57
58    /// The number of rows of the matrix
59    pub fn rows(&self) -> usize {
60        self.rows
61    }
62
63    /// The number of cols of the matrix
64    pub fn cols(&self) -> usize {
65        self.cols
66    }
67
68    /// The shape of the matrix, as a `(rows, cols)` tuple
69    pub fn shape(&self) -> (usize, usize) {
70        (self.rows, self.cols)
71    }
72
73    /// The number of non-zero entries
74    pub fn nnz(&self) -> usize {
75        self.nnz
76    }
77
78    pub fn into_row_inds(self) -> RI {
79        self.row_inds
80    }
81
82    pub fn into_col_inds(self) -> CI {
83        self.col_inds
84    }
85
86    pub fn into_data(self) -> DI {
87        self.data
88    }
89
90    pub fn transpose_into(self) -> TriMatIter<CI, RI, DI> {
91        TriMatIter {
92            rows: self.cols,
93            cols: self.rows,
94            nnz: self.nnz,
95            row_inds: self.col_inds,
96            col_inds: self.row_inds,
97            data: self.data,
98        }
99    }
100}
101
102impl<'a, N, I, RI, CI, DI> TriMatIter<RI, CI, DI>
103where
104    I: 'a + SpIndex,
105    N: 'a + Clone,
106    RI: Clone + Iterator<Item = &'a I>,
107    CI: Clone + Iterator<Item = &'a I>,
108    DI: Clone + Iterator<Item = &'a N>,
109{
110    /// Consume `TriMatIter` and produce a CSC matrix
111    pub fn into_csc<Iptr: SpIndex>(self) -> CsMatI<N, I, Iptr>
112    where
113        N: Add<Output = N>,
114    {
115        self.into_cs(CompressedStorage::CSC)
116    }
117
118    /// Consume `TriMatIter` and produce a CSR matrix
119    pub fn into_csr<Iptr: SpIndex>(self) -> CsMatI<N, I, Iptr>
120    where
121        N: Add<Output = N>,
122    {
123        self.into_cs(CompressedStorage::CSR)
124    }
125
126    /// Consume `TriMatIter` and produce a `CsMat` matrix with the chosen storage
127    pub fn into_cs<Iptr: SpIndex>(
128        self,
129        storage: crate::CompressedStorage,
130    ) -> CsMatI<N, I, Iptr>
131    where
132        N: Add<Output = N>,
133    {
134        // (i,j, input position, output position)
135        let mut rc: Vec<(I, I, N)> = Vec::new();
136
137        let mut nnz_max = 0;
138        for (v, (i, j)) in self.clone() {
139            rc.push((i, j, v.clone()));
140            nnz_max += 1;
141        }
142
143        match storage {
144            CompressedStorage::CSR => {
145                rc.sort_unstable_by_key(|i| (i.0, i.1));
146            }
147            CompressedStorage::CSC => {
148                rc.sort_unstable_by_key(|i| (i.1, i.0));
149            }
150        }
151
152        let outer_idx = |idx_r: I, idx_c: I| match storage {
153            CompressedStorage::CSR => idx_r,
154            CompressedStorage::CSC => idx_c,
155        };
156
157        let outer_dims = match storage {
158            CompressedStorage::CSR => self.rows(),
159            CompressedStorage::CSC => self.cols(),
160        };
161
162        let mut slot = 0;
163        let mut indptr = vec![Iptr::zero(); outer_dims + 1];
164        let mut cur_outer = I::zero();
165
166        for rec in 0..nnz_max {
167            if rec > 0 {
168                if rc[rec - 1].0 == rc[rec].0 && rc[rec - 1].1 == rc[rec].1 {
169                    // got a duplicate - add the value in the current slot.
170                    rc[slot].2 = rc[slot].2.clone() + rc[rec].2.clone();
171                } else {
172                    // new cell -- fill it out
173                    slot += 1;
174                    rc[slot] = rc[rec].clone();
175                }
176            }
177
178            let new_outer = outer_idx(rc[rec].0, rc[rec].1);
179
180            while new_outer > cur_outer {
181                indptr[cur_outer.index() + 1] = Iptr::from_usize(slot);
182                cur_outer += I::one();
183            }
184        }
185
186        // Ensure that slot == nnz
187        if nnz_max > 0 {
188            slot += 1;
189        }
190        // fill indptr up to the end
191        while I::from_usize(outer_dims) > cur_outer {
192            indptr[cur_outer.index() + 1] = Iptr::from_usize(slot);
193            cur_outer += I::one();
194        }
195
196        rc.truncate(slot);
197
198        let mut data: Vec<N> = Vec::with_capacity(slot);
199        let mut indices: Vec<I> = vec![I::zero(); slot];
200
201        for (n, (i, j, v)) in rc.into_iter().enumerate() {
202            assert!({
203                let outer = outer_idx(i, j);
204                n >= indptr[outer.index()].index()
205                    && n < indptr[outer.index() + 1].index()
206            });
207
208            data.push(v);
209
210            match storage {
211                CompressedStorage::CSR => indices[n] = j,
212                CompressedStorage::CSC => indices[n] = i,
213            }
214        }
215
216        CsMatI {
217            storage,
218            nrows: self.rows,
219            ncols: self.cols,
220            indptr: crate::IndPtr::new_trusted(indptr),
221            indices,
222            data,
223        }
224    }
225}