1use std::ops::Add;
5
6use crate::indexing::SpIndex;
7use crate::sparse::{CsMatI, TriMatIter};
8use crate::CompressedStorage;
9
10impl<'a, N, I, RI, CI, DI> Iterator for TriMatIter<RI, CI, DI>
11where
12 I: 'a + SpIndex,
13 N: 'a,
14 RI: Iterator<Item = &'a I>,
15 CI: Iterator<Item = &'a I>,
16 DI: Iterator<Item = &'a N>,
17{
18 type Item = (&'a N, (I, I));
19
20 fn next(&mut self) -> Option<<Self as Iterator>::Item> {
21 match (self.row_inds.next(), self.col_inds.next(), self.data.next()) {
22 (Some(row), Some(col), Some(val)) => Some((val, (*row, *col))),
23 _ => None,
24 }
25 }
26
27 fn size_hint(&self) -> (usize, Option<usize>) {
28 self.row_inds.size_hint() }
30}
31
32impl<'a, N, I, RI, CI, DI> TriMatIter<RI, CI, DI>
33where
34 I: 'a + SpIndex,
35 N: 'a,
36 RI: Iterator<Item = &'a I>,
37 CI: Iterator<Item = &'a I>,
38 DI: Iterator<Item = &'a N>,
39{
40 pub fn new(
42 shape: (usize, usize),
43 nnz: usize,
44 row_inds: RI,
45 col_inds: CI,
46 data: DI,
47 ) -> Self {
48 Self {
49 rows: shape.0,
50 cols: shape.1,
51 nnz,
52 row_inds,
53 col_inds,
54 data,
55 }
56 }
57
58 pub fn rows(&self) -> usize {
60 self.rows
61 }
62
63 pub fn cols(&self) -> usize {
65 self.cols
66 }
67
68 pub fn shape(&self) -> (usize, usize) {
70 (self.rows, self.cols)
71 }
72
73 pub fn nnz(&self) -> usize {
75 self.nnz
76 }
77
78 pub fn into_row_inds(self) -> RI {
79 self.row_inds
80 }
81
82 pub fn into_col_inds(self) -> CI {
83 self.col_inds
84 }
85
86 pub fn into_data(self) -> DI {
87 self.data
88 }
89
90 pub fn transpose_into(self) -> TriMatIter<CI, RI, DI> {
91 TriMatIter {
92 rows: self.cols,
93 cols: self.rows,
94 nnz: self.nnz,
95 row_inds: self.col_inds,
96 col_inds: self.row_inds,
97 data: self.data,
98 }
99 }
100}
101
102impl<'a, N, I, RI, CI, DI> TriMatIter<RI, CI, DI>
103where
104 I: 'a + SpIndex,
105 N: 'a + Clone,
106 RI: Clone + Iterator<Item = &'a I>,
107 CI: Clone + Iterator<Item = &'a I>,
108 DI: Clone + Iterator<Item = &'a N>,
109{
110 pub fn into_csc<Iptr: SpIndex>(self) -> CsMatI<N, I, Iptr>
112 where
113 N: Add<Output = N>,
114 {
115 self.into_cs(CompressedStorage::CSC)
116 }
117
118 pub fn into_csr<Iptr: SpIndex>(self) -> CsMatI<N, I, Iptr>
120 where
121 N: Add<Output = N>,
122 {
123 self.into_cs(CompressedStorage::CSR)
124 }
125
126 pub fn into_cs<Iptr: SpIndex>(
128 self,
129 storage: crate::CompressedStorage,
130 ) -> CsMatI<N, I, Iptr>
131 where
132 N: Add<Output = N>,
133 {
134 let mut rc: Vec<(I, I, N)> = Vec::new();
136
137 let mut nnz_max = 0;
138 for (v, (i, j)) in self.clone() {
139 rc.push((i, j, v.clone()));
140 nnz_max += 1;
141 }
142
143 match storage {
144 CompressedStorage::CSR => {
145 rc.sort_unstable_by_key(|i| (i.0, i.1));
146 }
147 CompressedStorage::CSC => {
148 rc.sort_unstable_by_key(|i| (i.1, i.0));
149 }
150 }
151
152 let outer_idx = |idx_r: I, idx_c: I| match storage {
153 CompressedStorage::CSR => idx_r,
154 CompressedStorage::CSC => idx_c,
155 };
156
157 let outer_dims = match storage {
158 CompressedStorage::CSR => self.rows(),
159 CompressedStorage::CSC => self.cols(),
160 };
161
162 let mut slot = 0;
163 let mut indptr = vec![Iptr::zero(); outer_dims + 1];
164 let mut cur_outer = I::zero();
165
166 for rec in 0..nnz_max {
167 if rec > 0 {
168 if rc[rec - 1].0 == rc[rec].0 && rc[rec - 1].1 == rc[rec].1 {
169 rc[slot].2 = rc[slot].2.clone() + rc[rec].2.clone();
171 } else {
172 slot += 1;
174 rc[slot] = rc[rec].clone();
175 }
176 }
177
178 let new_outer = outer_idx(rc[rec].0, rc[rec].1);
179
180 while new_outer > cur_outer {
181 indptr[cur_outer.index() + 1] = Iptr::from_usize(slot);
182 cur_outer += I::one();
183 }
184 }
185
186 if nnz_max > 0 {
188 slot += 1;
189 }
190 while I::from_usize(outer_dims) > cur_outer {
192 indptr[cur_outer.index() + 1] = Iptr::from_usize(slot);
193 cur_outer += I::one();
194 }
195
196 rc.truncate(slot);
197
198 let mut data: Vec<N> = Vec::with_capacity(slot);
199 let mut indices: Vec<I> = vec![I::zero(); slot];
200
201 for (n, (i, j, v)) in rc.into_iter().enumerate() {
202 assert!({
203 let outer = outer_idx(i, j);
204 n >= indptr[outer.index()].index()
205 && n < indptr[outer.index() + 1].index()
206 });
207
208 data.push(v);
209
210 match storage {
211 CompressedStorage::CSR => indices[n] = j,
212 CompressedStorage::CSC => indices[n] = i,
213 }
214 }
215
216 CsMatI {
217 storage,
218 nrows: self.rows,
219 ncols: self.cols,
220 indptr: crate::IndPtr::new_trusted(indptr),
221 indices,
222 data,
223 }
224 }
225}