1use crate::indexing::SpIndex;
12use crate::sparse::prelude::*;
13
14use std::ops::{Add, Deref, DerefMut};
15use std::slice::Iter;
16
/// Handle to one stored entry of a triplet matrix.
///
/// The wrapped `usize` is the entry's position inside the parallel
/// `row_inds` / `col_inds` / `data` storages. Handles are produced by
/// `find_locations` and consumed by `set_triplet`.
///
/// Besides equality, handles can be hashed and ordered (ordering follows the
/// storage position), so they can be collected into hash/tree sets or sorted.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct TripletIndex(pub usize);
20
21impl<'a, N, I, IS, DS> IntoIterator for &'a TriMatBase<IS, DS>
22where
23 I: 'a + SpIndex,
24 N: 'a,
25 IS: Deref<Target = [I]>,
26 DS: Deref<Target = [N]>,
27{
28 type Item = (&'a N, (I, I));
29 type IntoIter = TriMatIter<Iter<'a, I>, Iter<'a, I>, Iter<'a, N>>;
30 fn into_iter(self) -> Self::IntoIter {
31 self.triplet_iter()
32 }
33}
34
35impl<'a, N, I> IntoIterator for TriMatViewI<'a, N, I>
36where
37 I: SpIndex,
38{
39 type Item = (&'a N, (I, I));
40 type IntoIter = TriMatIter<Iter<'a, I>, Iter<'a, I>, Iter<'a, N>>;
41 fn into_iter(self) -> Self::IntoIter {
42 self.triplet_iter_rbr()
43 }
44}
45
46impl<N, I, IS, DS> SparseMat for TriMatBase<IS, DS>
47where
48 I: SpIndex,
49 IS: Deref<Target = [I]>,
50 DS: Deref<Target = [N]>,
51{
52 fn rows(&self) -> usize {
53 self.rows()
54 }
55
56 fn cols(&self) -> usize {
57 self.cols()
58 }
59
60 fn nnz(&self) -> usize {
61 self.nnz()
62 }
63}
64
65impl<'a, N, I, IS, DS> SparseMat for &'a TriMatBase<IS, DS>
66where
67 I: 'a + SpIndex,
68 N: 'a,
69 IS: Deref<Target = [I]>,
70 DS: Deref<Target = [N]>,
71{
72 fn rows(&self) -> usize {
73 (*self).rows()
74 }
75
76 fn cols(&self) -> usize {
77 (*self).cols()
78 }
79
80 fn nnz(&self) -> usize {
81 (*self).nnz()
82 }
83}
84
85impl<N, I: SpIndex> TriMatI<N, I> {
87 pub fn new(shape: (usize, usize)) -> Self {
89 Self {
90 rows: shape.0,
91 cols: shape.1,
92 row_inds: Vec::new(),
93 col_inds: Vec::new(),
94 data: Vec::new(),
95 }
96 }
97
98 pub fn with_capacity(shape: (usize, usize), cap: usize) -> Self {
101 Self {
102 rows: shape.0,
103 cols: shape.1,
104 row_inds: Vec::with_capacity(cap),
105 col_inds: Vec::with_capacity(cap),
106 data: Vec::with_capacity(cap),
107 }
108 }
109
110 pub fn from_triplets(
118 shape: (usize, usize),
119 row_inds: Vec<I>,
120 col_inds: Vec<I>,
121 data: Vec<N>,
122 ) -> Self {
123 assert_eq!(
124 row_inds.len(),
125 col_inds.len(),
126 "all inputs should have the same length"
127 );
128 assert_eq!(
129 data.len(),
130 col_inds.len(),
131 "all inputs should have the same length"
132 );
133 assert_eq!(
134 row_inds.len(),
135 data.len(),
136 "all inputs should have the same length"
137 );
138 assert!(
139 row_inds.iter().all(|&i| i.index() < shape.0),
140 "row indices should be within shape"
141 );
142 assert!(
143 col_inds.iter().all(|&j| j.index() < shape.1),
144 "col indices should be within shape"
145 );
146 Self {
147 rows: shape.0,
148 cols: shape.1,
149 row_inds,
150 col_inds,
151 data,
152 }
153 }
154
155 pub fn add_triplet(&mut self, row: usize, col: usize, val: N) {
157 assert!(row < self.rows);
158 assert!(col < self.cols);
159 self.row_inds.push(I::from_usize(row));
160 self.col_inds.push(I::from_usize(col));
161 self.data.push(val);
162 }
163
164 pub fn reserve(&mut self, cap: usize) {
166 self.row_inds.reserve(cap);
167 self.col_inds.reserve(cap);
168 self.data.reserve(cap);
169 }
170
171 pub fn reserve_exact(&mut self, cap: usize) {
173 self.row_inds.reserve_exact(cap);
174 self.col_inds.reserve_exact(cap);
175 self.data.reserve_exact(cap);
176 }
177}
178
179impl<N, I: SpIndex, IStorage, DStorage> TriMatBase<IStorage, DStorage>
181where
182 IStorage: Deref<Target = [I]>,
183 DStorage: Deref<Target = [N]>,
184{
185 pub fn rows(&self) -> usize {
187 self.rows
188 }
189
190 pub fn cols(&self) -> usize {
192 self.cols
193 }
194
195 pub fn shape(&self) -> (usize, usize) {
197 (self.rows, self.cols)
198 }
199
200 pub fn nnz(&self) -> usize {
202 self.data.len()
203 }
204
205 pub fn row_inds(&self) -> &[I] {
207 &self.row_inds[..]
208 }
209
210 pub fn col_inds(&self) -> &[I] {
212 &self.col_inds[..]
213 }
214
215 pub fn data(&self) -> &[N] {
217 &self.data[..]
218 }
219
220 pub fn find_locations(&self, row: usize, col: usize) -> Vec<TripletIndex> {
222 self.row_inds
223 .iter()
224 .zip(self.col_inds.iter())
225 .enumerate()
226 .filter_map(|(ind, (&i, &j))| {
227 if i.index_unchecked() == row && j.index_unchecked() == col {
228 Some(TripletIndex(ind))
229 } else {
230 None
231 }
232 })
233 .collect()
234 }
235
236 pub fn transpose_view(&self) -> TriMatViewI<'_, N, I> {
238 TriMatViewI {
239 rows: self.cols,
240 cols: self.rows,
241 row_inds: &self.col_inds[..],
242 col_inds: &self.row_inds[..],
243 data: &self.data[..],
244 }
245 }
246
247 pub fn triplet_iter(
249 &self,
250 ) -> TriMatIter<Iter<'_, I>, Iter<'_, I>, Iter<'_, N>> {
251 TriMatIter {
252 rows: self.rows,
253 cols: self.cols,
254 nnz: self.nnz(),
255 row_inds: self.row_inds.iter(),
256 col_inds: self.col_inds.iter(),
257 data: self.data.iter(),
258 }
259 }
260
261 pub fn to_csc<Iptr: SpIndex>(&self) -> CsMatI<N, I, Iptr>
263 where
264 N: Clone + Add<Output = N>,
265 {
266 self.triplet_iter().into_csc()
267 }
268
269 pub fn to_csr<Iptr: SpIndex>(&self) -> CsMatI<N, I, Iptr>
271 where
272 N: Clone + Add<Output = N>,
273 {
274 self.triplet_iter().into_csr()
275 }
276
277 pub fn view(&self) -> TriMatViewI<'_, N, I> {
278 TriMatViewI {
279 rows: self.rows,
280 cols: self.cols,
281 row_inds: &self.row_inds[..],
282 col_inds: &self.col_inds[..],
283 data: &self.data[..],
284 }
285 }
286}
287
288impl<'a, N, I: SpIndex> TriMatBase<&'a [I], &'a [N]> {
289 pub fn triplet_iter_rbr(
293 &self,
294 ) -> TriMatIter<Iter<'a, I>, Iter<'a, I>, Iter<'a, N>> {
295 TriMatIter {
296 rows: self.rows,
297 cols: self.cols,
298 nnz: self.nnz(),
299 row_inds: self.row_inds.iter(),
300 col_inds: self.col_inds.iter(),
301 data: self.data.iter(),
302 }
303 }
304}
305
306impl<N, I: SpIndex, IStorage, DStorage> TriMatBase<IStorage, DStorage>
307where
308 IStorage: DerefMut<Target = [I]>,
309 DStorage: DerefMut<Target = [N]>,
310{
311 pub fn set_triplet(
314 &mut self,
315 TripletIndex(triplet_ind): TripletIndex,
316 row: usize,
317 col: usize,
318 val: N,
319 ) {
320 self.row_inds[triplet_ind] = I::from_usize(row);
321 self.col_inds[triplet_ind] = I::from_usize(col);
322 self.data[triplet_ind] = val;
323 }
324
325 pub fn view_mut(&mut self) -> TriMatViewMutI<'_, N, I> {
326 TriMatViewMutI {
327 rows: self.rows,
328 cols: self.cols,
329 row_inds: &mut self.row_inds[..],
330 col_inds: &mut self.col_inds[..],
331 data: &mut self.data[..],
332 }
333 }
334}
335
#[cfg(test)]
mod test {

    use super::{TriMat, TriMatI};
    use crate::sparse::{CsMat, CsMatI};

    /// Entries pushed one by one in column-friendly order convert exactly.
    #[test]
    fn triplet_incremental() {
        let entries = [
            (0, 0, 1.),
            (0, 1, 2.),
            (1, 0, 3.),
            (2, 3, 4.),
            (3, 2, 5.),
            (3, 3, 6.),
        ];
        let mut triplet_mat = TriMatI::with_capacity((4, 4), 6);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let csc: CsMatI<_, i32> = triplet_mat.to_csc();
        let expected = CsMatI::new_csc(
            (4, 4),
            vec![0, 2, 3, 4, 6],
            vec![0, 1, 0, 3, 2, 3],
            vec![1., 3., 2., 5., 4., 6.],
        );
        assert_eq!(csc, expected);
    }

    /// Insertion order must not influence the compressed result.
    #[test]
    fn triplet_unordered() {
        let entries = [
            (0, 1, 2.),
            (0, 0, 1.),
            (1, 0, 3.),
            (2, 3, 4.),
            (3, 3, 6.),
            (3, 2, 5.),
        ];
        let mut triplet_mat = TriMat::with_capacity((4, 4), 6);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let expected = CsMat::new_csc(
            (4, 4),
            vec![0, 2, 3, 4, 6],
            vec![0, 1, 0, 3, 2, 3],
            vec![1., 3., 2., 5., 4., 6.],
        );

        let csc = triplet_mat.to_csc();
        assert_eq!(csc, expected);

        let csr_to_csc = triplet_mat.to_csr().to_csc();
        assert_eq!(csr_to_csc, expected);
    }

    /// Duplicate `(3, 2)` entries (3. and 2.) are summed to 5.
    #[test]
    fn triplet_additions() {
        let entries = [
            (0, 1, 2.),
            (0, 0, 1.),
            (3, 2, 3.),
            (1, 0, 3.),
            (2, 3, 4.),
            (3, 3, 6.),
            (3, 2, 2.),
        ];
        let mut triplet_mat = TriMat::with_capacity((4, 4), 6);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let csc = triplet_mat.to_csc();
        let csr = triplet_mat.to_csr();
        let expected = CsMat::new_csc(
            (4, 4),
            vec![0, 2, 3, 4, 6],
            vec![0, 1, 0, 3, 2, 3],
            vec![1., 3., 2., 5., 4., 6.],
        );
        assert_eq!(csc, expected);
        assert_eq!(csr, expected.to_csr());
    }

    /// Building from parallel vectors matches the expected compressed form.
    #[test]
    fn triplet_from_vecs() {
        let row_inds = vec![0, 0, 1, 2, 3, 3, 4, 4];
        let col_inds = vec![0, 1, 0, 3, 2, 3, 1, 3];
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let triplet_mat =
            super::TriMat::from_triplets((5, 4), row_inds, col_inds, data);

        let csc = triplet_mat.to_csc();
        let csr = triplet_mat.to_csr();
        let expected = CsMat::new_csc(
            (5, 4),
            vec![0, 2, 4, 5, 8],
            vec![0, 1, 0, 4, 3, 2, 3, 4],
            vec![1, 3, 2, 7, 5, 4, 6, 8],
        );

        assert_eq!(csc, expected);
        assert_eq!(csr, expected.to_csr());
    }

    /// A located entry can be overwritten in place.
    #[test]
    fn triplet_mutate_entry() {
        let entries = [
            (0, 0, 1.),
            (0, 1, 2.),
            (1, 0, 3.),
            (2, 3, 4.),
            (3, 2, 5.),
            (3, 3, 6.),
        ];
        let mut triplet_mat = TriMat::with_capacity((4, 4), 6);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let locations = triplet_mat.find_locations(2, 3);
        assert_eq!(locations.len(), 1);
        triplet_mat.set_triplet(locations[0], 2, 3, 0.);

        let csc = triplet_mat.to_csc();
        let csr = triplet_mat.to_csr();
        let expected = CsMat::new_csc(
            (4, 4),
            vec![0, 2, 3, 4, 6],
            vec![0, 1, 0, 3, 2, 3],
            vec![1., 3., 2., 5., 0., 6.],
        );
        assert_eq!(csc, expected);
        assert_eq!(csr, expected.to_csr());
    }

    /// CSR and CSC conversions agree when duplicates are present.
    #[test]
    fn triplet_to_csr() {
        let entries = [
            (0, 1, 2.),
            (0, 0, 1.),
            (3, 2, 3.),
            (1, 0, 3.),
            (2, 3, 4.),
            (3, 3, 6.),
            (3, 2, 2.),
        ];
        let mut triplet_mat = TriMat::with_capacity((4, 4), 6);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let csr = triplet_mat.to_csr();
        let csc = triplet_mat.to_csc();

        let expected = CsMat::new_csc(
            (4, 4),
            vec![0, 2, 3, 4, 6],
            vec![0, 1, 0, 3, 2, 3],
            vec![1., 3., 2., 5., 4., 6.],
        );

        assert_eq!(csc, expected);
        assert_eq!(csr, expected.to_csr());
    }

    /// Larger rectangular matrix with scattered duplicates.
    #[test]
    fn triplet_complex() {
        let entries = [
            (5, 8, 1),
            (0, 0, 1),
            (0, 8, 2),
            (0, 4, 2),
            (2, 0, 1),
            (2, 1, 2),
            (2, 3, 2),
            (2, 6, 3),
            (2, 8, 2),
            (1, 0, 1),
            (1, 5, 1),
            (1, 8, 1),
            (0, 4, 4),
            (3, 8, 2),
            (3, 5, 4),
            (5, 8, 1),
            (3, 2, 9),
            (3, 0, 1),
            (4, 0, 1),
            (4, 8, 2),
            (1, 8, 1),
            (4, 3, 5),
            (5, 0, 1),
            (5, 5, 7),
            (2, 3, 1),
            (5, 7, 8),
        ];
        let mut triplet_mat = TriMat::with_capacity((6, 9), 22);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }

        let csc = triplet_mat.to_csc();

        let expected = CsMat::new_csc(
            (6, 9),
            vec![0, 6, 7, 8, 10, 11, 14, 15, 16, 22],
            vec![
                0, 1, 2, 3, 4, 5, 2, 3, 2, 4, 0, 1, 3, 5, 2, 5, 0, 1, 2, 3, 4,
                5,
            ],
            vec![
                1, 1, 1, 1, 1, 1, 2, 9, 3, 5, 6, 1, 4, 7, 3, 8, 2, 2, 2, 2, 2,
                2,
            ],
        );

        assert_eq!(csc, expected);

        let csr = triplet_mat.to_csr();
        assert_eq!(csr, expected.to_csr());
    }

    /// Empty rows/columns must yield repeated entries in `indptr`.
    #[test]
    fn triplet_empty_lines() {
        let tri_mat = TriMatI::new((2, 4));
        let m: CsMat<u64> = tri_mat.to_csr();
        assert_eq!(m.indptr(), &[0, 0, 0][..]);
        assert_eq!(m.indices(), &[]);
        assert_eq!(m.data(), &[]);

        let m: CsMat<u64> = tri_mat.to_csc();
        assert_eq!(m.indptr(), &[0, 0, 0, 0, 0][..]);
        assert_eq!(m.indices(), &[]);
        assert_eq!(m.data(), &[]);

        let entries = [
            (5, 8, 1),
            (0, 0, 1),
            (0, 8, 2),
            (0, 4, 2),
            (2, 0, 1),
            (2, 1, 2),
            (2, 3, 2),
            (2, 8, 2),
            (0, 4, 4),
            (3, 8, 2),
            (3, 5, 4),
            (5, 8, 1),
            (3, 0, 1),
            (4, 0, 1),
            (4, 8, 2),
            (4, 3, 5),
            (5, 0, 1),
            (5, 5, 7),
            (2, 3, 1),
        ];
        let mut triplet_mat = TriMat::with_capacity((6, 9), 22);
        for &(row, col, val) in &entries {
            triplet_mat.add_triplet(row, col, val);
        }
        let csc = triplet_mat.to_csc();

        let expected = CsMat::new_csc(
            (6, 9),
            vec![0, 5, 6, 6, 8, 9, 11, 11, 11, 16],
            vec![0, 2, 3, 4, 5, 2, 2, 4, 0, 3, 5, 0, 2, 3, 4, 5],
            vec![1, 1, 1, 1, 1, 2, 3, 5, 6, 4, 7, 2, 2, 2, 2, 2],
        );

        assert_eq!(csc, expected);

        let csr = triplet_mat.to_csr();
        assert_eq!(csr, expected.to_csr());

        let mut triplet_mat = TriMat::with_capacity((4, 6), 2);
        for &(row, col, val) in &[(1, 1, 1), (0, 3, 2)] {
            triplet_mat.add_triplet(row, col, val);
        }

        let m = triplet_mat.to_csc();
        assert_eq!(m.indptr(), &[0, 0, 1, 1, 2, 2, 2][..]);
        assert_eq!(m.indices(), &[1, 0]);
        assert_eq!(m.data(), &[1, 2]);
    }
}