1use std::ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, MulAssign, Range};
2use std::iter::zip;
3use std::fmt::{Display, Debug};
4use delegate::delegate;
5use nalgebra_sparse::na::{Scalar, ClosedAddAssign, ClosedSubAssign, ClosedMulAssign};
6use nalgebra_sparse::{CscMatrix, CooMatrix};
7use num_traits::{Zero, One, ToPrimitive};
8use auto_impl_ops::auto_ops;
9use sprs::PermView;
10use yui_core::{Ring, RingOps};
11use crate::dense::*;
12use super::sp_vec::SpVec;
13use super::triang::TriangularType;
14
15#[derive(Clone, PartialEq, Eq)]
16pub struct SpMat<R> {
17 inner: CscMatrix<R>
18}
19
20impl<R> MatTrait for SpMat<R> {
21 fn shape(&self) -> (usize, usize) {
22 (self.inner.nrows(), self.inner.ncols())
23 }
24}
25
26impl<R> SpMat<R> {
27 pub(crate) fn inner(&self) -> &CscMatrix<R> {
28 &self.inner
29 }
30
31 pub(crate) fn into_inner(self) -> CscMatrix<R> {
32 self.inner
33 }
34
35 pub fn data(&self) -> (&[usize], &[usize], &[R]) {
36 self.inner.csc_data()
37 }
38
39 pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<R>) {
40 self.inner.disassemble()
41 }
42
43 pub fn zero(shape: (usize, usize)) -> Self {
44 let csc = CscMatrix::zeros(shape.0, shape.1);
45 Self::from(csc)
46 }
47
48 pub fn is_zero(&self) -> bool
49 where R: Zero {
50 self.inner.values().iter().all(|a| a.is_zero())
51 }
52
53 pub fn id(n: usize) -> Self
54 where R: Scalar + One {
55 let csc = CscMatrix::identity(n);
56 Self::from(csc)
57 }
58
59 pub fn is_id(&self) -> bool
60 where R: Scalar + One + Zero {
61 self.is_square() && self.iter().all(|(i, j, a)|
62 (i == j && a.is_one()) || (i != j && a.is_zero())
63 )
64 }
65
66 pub fn is_triang(&self, t: TriangularType) -> bool
67 where R: Zero {
68 if self.nrows() != self.ncols() {
69 return false
70 }
71
72 if t.is_upper() {
73 self.iter_nz().all(|(i, j, _)| i <= j )
74 } else {
75 self.iter_nz().all(|(i, j, _)| i >= j )
76 }
77 }
78
79 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &R)> {
80 self.inner.triplet_iter()
81 }
82
83 pub fn iter_nz(&self) -> impl Iterator<Item = (usize, usize, &R)>
84 where R: Zero {
85 self.iter().filter(|e| !e.2.is_zero())
86 }
87
88 pub fn into_dense(self) -> Mat<R>
89 where R: Scalar + Zero + ClosedAddAssign {
90 self.into()
91 }
92
93 pub fn nnz(&self) -> usize {
94 self.inner.nnz()
95 }
96
97 pub fn density(&self) -> f64 {
98 let (m, n) = self.shape();
99 if m == 0 || n == 0 {
100 return 0.0
101 }
102
103 let nnz = self.nnz().to_f64().unwrap();
104 let total = (m * n).to_f64().unwrap();
105
106 nnz / total
107 }
108
109 pub fn redundancy(&self) -> f64
110 where R: Zero {
111 let nnz = self.nnz().to_f64().unwrap();
112 let red = self.iter().filter(|(_, _, a)| a.is_zero()).count().to_f64().unwrap();
113 red / nnz
114 }
115
116 pub fn mean_weight(&self) -> f64
117 where R: Ring, for<'x> &'x R: RingOps<R> {
118 let nnz = self.nnz().to_f64().unwrap();
119 let w = self.iter().map(|(_, _, a)| a.c_weight()).sum::<f64>();
120 w / nnz
121 }
122}
123
124impl<R> SpMat<R>
125where R: Scalar + Clone + Zero + ClosedAddAssign {
126 pub fn from_entries<T>(shape: (usize, usize), entries: T) -> Self
127 where T: IntoIterator<Item = (usize, usize, R)> {
128 let mut coo = CooMatrix::new(shape.0, shape.1);
129 for (i, j, a) in entries {
130 if a.is_zero() {
131 continue;
132 }
133 coo.push(i, j, a)
134 }
135 let csc = CscMatrix::from(&coo);
136 Self::from(csc)
137 }
138
139 pub fn from_col_vecs<I>(nrows: usize, vecs: I) -> Self
140 where I: IntoIterator<Item = SpVec<R>> {
141 let mut col_offsets = vec![0];
142 let mut row_indices = vec![];
143 let mut values = vec![];
144
145 for v in vecs.into_iter() {
146 assert_eq!(nrows, v.dim());
147 let (_, mut v_rows, mut v_values) = v.into_inner().disassemble();
148
149 row_indices.append(&mut v_rows);
150 values.append(&mut v_values);
151 col_offsets.push(row_indices.len());
152 }
153
154 let ncols = col_offsets.len() - 1;
155 let csc = CscMatrix::try_from_csc_data(nrows, ncols, col_offsets, row_indices, values).unwrap();
156 Self::from(csc)
157 }
158
159 pub fn from_dense_data<I>(shape: (usize, usize), data: I) -> Self
160 where I: IntoIterator<Item = R> {
161 let n = shape.1;
162 Self::from_entries(
163 shape,
164 data.into_iter().enumerate().map(|(k, a)| {
165 let (i, j) = (k / n, k % n);
166 (i, j, a)
167 })
168 )
169 }
170
171 pub fn col_vec(&self, j: usize) -> SpVec<R>
172 where R: Scalar + Zero + ClosedAddAssign {
173 let col = self.inner.col(j);
174 let iter = Iterator::zip(
175 col.row_indices().iter().cloned(),
176 col.values().iter().cloned()
177 );
178 SpVec::from_entries(self.nrows(), iter)
179 }
180
181 pub fn transpose(&self) -> Self {
182 self.inner.transpose().into()
183 }
184
185 pub fn extract<F>(&self, shape: (usize, usize), f: F) -> SpMat<R>
186 where F: Fn(usize, usize) -> Option<(usize, usize)> {
187 SpMat::from_entries(shape, self.iter().filter_map(|(i, j, a)|
188 f(i, j).map(|(i, j)| (i, j, a.clone()))
189 ))
190 }
191
192 pub fn permute(&self, p: PermView, q: PermView) -> SpMat<R> {
193 self.extract(self.shape(), |i, j| Some((p.at(i), q.at(j))))
194 }
195
196 pub fn permute_rows(&self, p: PermView) -> SpMat<R> {
197 let id = PermView::identity(self.ncols());
198 self.permute(p, id)
199 }
200
201 pub fn permute_cols(&self, q: PermView) -> SpMat<R> {
202 let id = PermView::identity(self.nrows());
203 self.permute(id, q)
204 }
205
206 pub fn submat(&self, rows: Range<usize>, cols: Range<usize>) -> SpMat<R> {
207 let (i0, i1) = (rows.start, rows.end);
208 let (j0, j1) = (cols.start, cols.end);
209
210 assert!(i0 <= i1 && i1 <= self.nrows());
211 assert!(j0 <= j1 && j1 <= self.ncols());
212
213 let shape = (i1 - i0, j1 - j0);
214 self.extract(shape, |i, j|
215 (rows.contains(&i) && cols.contains(&j)).then( ||
216 (i - i0, j - j0)
217 )
218 )
219 }
220
221 pub fn submat_rows(&self, rows: Range<usize>) -> SpMat<R> {
222 let n = self.ncols();
223 self.submat(rows, 0 .. n)
224 }
225
226 pub fn submat_cols(&self, cols: Range<usize>) -> SpMat<R> {
227 let m = self.nrows();
228 self.submat(0 .. m, cols)
229 }
230
231 pub fn divide4(&self, point: (usize, usize)) -> [SpMat<R>; 4] {
232 let (m, n) = self.shape();
233 let (k, l) = point;
234 assert!(k <= m);
235 assert!(l <= n);
236
237 let mut a = CooMatrix::new(k, l);
238 let mut b = CooMatrix::new(k, n - l);
239 let mut c = CooMatrix::new(m - k, l);
240 let mut d = CooMatrix::new(m - k, n - l);
241
242 for (i, j, r) in self.iter() {
243 if r.is_zero() { continue }
244 let r = r.clone();
245 match ((0..k).contains(&i), (0..l).contains(&j)) {
246 (true , true ) => a.push(i, j, r),
247 (true , false) => b.push(i, j - l, r),
248 (false, true ) => c.push(i - k, j, r),
249 (false, false) => d.push(i - k, j - l, r),
250 }
251 }
252
253 [a, b, c, d].map(|x|
254 CscMatrix::from(&x).into()
255 )
256 }
257
258 pub fn combine_blocks(blocks: [&SpMat<R>; 4]) -> SpMat<R> {
259 let [a, b, c, d] = blocks;
260
261 assert_eq!(a.nrows(), b.nrows());
262 assert_eq!(c.nrows(), d.nrows());
263 assert_eq!(a.ncols(), c.ncols());
264 assert_eq!(b.ncols(), d.ncols());
265
266 let (m, n) = (a.nrows() + c.nrows(), a.ncols() + b.ncols());
267 let (k, l) = a.shape();
268
269 let entries = zip(
270 [a, b, c, d],
271 [(0,0), (0,l), (k,0), (k,l)]
272 ).flat_map(|(x, (di, dj))|
273 x.iter().map(move |(i, j, r)|
274 (i + di, j + dj, r.clone())
275 )
276 );
277
278 Self::from_entries((m, n), entries)
279 }
280
281 pub fn concat(&self, b: &Self) -> Self {
282 let zero = |m, n| SpMat::<R>::zero((m, n));
283 Self::combine_blocks([
284 self,
285 b,
286 &zero(0, self.ncols()),
287 &zero(0, b.ncols())
288 ])
289 }
290
291 pub fn stack(&self, b: &Self) -> Self {
292 let zero = |m, n| SpMat::<R>::zero((m, n));
293 Self::combine_blocks([
294 self,
295 &zero(self.nrows(), 0),
296 b,
297 &zero(b.nrows(), 0)
298 ])
299 }
300
301 pub fn extend_cols(&mut self, b: Self) {
302 assert_eq!(self.nrows(), b.nrows());
303
304 if b.ncols() == 0 {
305 return
306 }
307
308 let shape = (self.nrows(), self.ncols() + b.ncols());
309 let l = std::mem::replace(&mut self.inner, CscMatrix::zeros(0, 0));
310 let r = b.inner;
311
312 let (mut col_offsets, mut row_indices, mut values) = l.disassemble();
313 let (c, mut r, mut v) = r.disassemble();
314
315 let offset = col_offsets.pop().unwrap(); col_offsets.extend(c.into_iter().map(|i| offset + i));
317 row_indices.append(&mut r);
318 values.append(&mut v);
319
320 self.inner = CscMatrix::try_from_csc_data(
321 shape.0, shape.1,
322 col_offsets,
323 row_indices,
324 values
325 ).unwrap();
326 }
327
328 pub fn from_row_perm(p: PermView) -> Self
330 where R: One {
331 let n = p.dim();
332 Self::from_entries((n, n), (0..n).map(|i|
333 (p.at(i), i, R::one())
334 ))
335 }
336
337 pub fn from_col_perm(p: PermView) -> Self
339 where R: One {
340 let n = p.dim();
341 Self::from_entries((n, n), (0..n).map(|i|
342 (i, p.at(i), R::one())
343 ))
344 }
345}
346
347impl<R> From<CscMatrix<R>> for SpMat<R> {
348 fn from(inner: CscMatrix<R>) -> Self {
349 Self { inner }
350 }
351}
352
353impl<R> From<Mat<R>> for SpMat<R>
354where R: Scalar + Zero {
355 fn from(value: Mat<R>) -> Self {
356 let csc = CscMatrix::from(value.inner());
357 Self::from(csc)
358 }
359}
360
361impl<R> Default for SpMat<R> {
362 fn default() -> Self {
363 Self::zero((0, 0))
364 }
365}
366
367impl<R> Neg for SpMat<R>
368where R: Scalar + Neg<Output = R> {
369 type Output = Self;
370 fn neg(self) -> Self::Output {
371 Self::from(-self.inner)
372 }
373}
374
375impl<R> Neg for &SpMat<R>
376where R: Scalar + Neg<Output = R> {
377 type Output = SpMat<R>;
378 fn neg(self) -> Self::Output {
379 SpMat::from(-&self.inner)
380 }
381}
382
/// Implements `&SpMat<R> <op> &SpMat<R>` by applying the same operator to the
/// wrapped `CscMatrix` references. `#[auto_ops]` is applied to this
/// reference/reference impl; the owned-operand forms used elsewhere
/// (e.g. `q * &a`) come from it.
macro_rules! impl_binop {
    ($op_trait:ident, $op_fn:ident) => {
        #[auto_ops]
        impl<'a, 'b, R> $op_trait<&'b SpMat<R>> for &'a SpMat<R>
        where
            R: Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output = R>,
        {
            type Output = SpMat<R>;

            fn $op_fn(self, rhs: &'b SpMat<R>) -> Self::Output {
                let (lhs_csc, rhs_csc) = (&self.inner, &rhs.inner);
                SpMat::from(lhs_csc.$op_fn(rhs_csc))
            }
        }
    };
}
397
398impl_binop!(Add, add);
399impl_binop!(Sub, sub);
400impl_binop!(Mul, mul);
401
402impl<R> Display for SpMat<R>
403where R: Display + Debug {
404 delegate! { to self.inner {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
406 }}
407}
408
409impl<R> Debug for SpMat<R>
410where R: Display + Debug {
411 delegate! { to self.inner {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
413 }}
414}
415
/// Serializes exactly as the wrapped `CscMatrix` does; the wrapper adds no
/// fields of its own.
#[cfg(feature = "serde")]
impl<R> serde::Serialize for SpMat<R>
where R: Clone + serde::Serialize {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where S: serde::Serializer {
        // Fully qualified so the call does not depend on `serde::Serialize`
        // being imported into this module.
        serde::Serialize::serialize(&self.inner, serializer)
    }
}
424
/// Deserializes a `CscMatrix` and wraps it; accepts exactly the format the
/// `Serialize` impl writes.
#[cfg(feature = "serde")]
impl<'de, R> serde::Deserialize<'de> for SpMat<R>
where R: Clone + serde::Deserialize<'de> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where D: serde::Deserializer<'de> {
        <CscMatrix<R> as serde::Deserialize<'de>>::deserialize(deserializer).map(Self::from)
    }
}
435
#[cfg(test)]
impl<R> SpMat<R>
where R: Scalar + Zero + One + ClosedAddAssign {
    /// Test helper: a random matrix of `shape` in which each position
    /// independently holds `R::one()` with probability `density`.
    pub fn rand(shape: (usize, usize), density: f64) -> Self {
        use rand::Rng;

        let (m, n) = shape;
        let mut rng = rand::rng();
        let mut entries = Vec::new();

        // One uniform draw per position, visited row by row.
        for i in 0..m {
            for j in 0..n {
                if rng.random::<f64>() < density {
                    entries.push((i, j, R::one()));
                }
            }
        }

        Self::from_entries(shape, entries)
    }
}
456
#[cfg(test)]
pub(super) mod tests {
    use itertools::Itertools;
    use sprs::PermOwned;
    use yui_core::num::Ratio;

    use super::*;

    // Entries given in row-major order are stored column-major in CSC.
    #[test]
    fn init() {
        let a = SpMat::from_entries((2, 2), [
            (0, 0, 1),
            (0, 1, 2),
            (1, 0, 3),
            (1, 1, 4)
        ]);
        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vec![1, 3, 2, 4]));
    }

    #[test]
    fn init_ratio() {
        type R = Ratio<i64>;
        let vals = (0..4).map(|i| R::new(i + 1, 5)).collect_vec();
        let a = SpMat::from_entries((2, 2), [
            (0, 0, vals[0].clone()),
            (0, 1, vals[2].clone()),
            (1, 0, vals[1].clone()),
            (1, 1, vals[3].clone())
        ]);
        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vals));
    }

    #[test]
    fn from_grid() {
        let a = SpMat::from_dense_data((2, 2), [1, 2, 3, 4]);
        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vec![1, 3, 2, 4]));
    }

    #[test]
    fn to_dense() {
        let a = SpMat::from_entries((2, 2), [
            (0, 0, 1),
            (0, 1, 2),
            (1, 0, 3),
            (1, 1, 4)
        ]);
        assert_eq!(a.into_dense(), Mat::from_data((2, 2), [1, 2, 3, 4]));
    }

    #[test]
    fn permute() {
        let p = PermOwned::new(vec![1, 2, 3, 0]);
        let q = PermOwned::new(vec![3, 0, 2, 1]);
        let a = SpMat::from_dense_data((4, 4), 0..16);
        let b = a.permute(p.view(), q.view());
        assert_eq!(b, SpMat::from_dense_data((4, 4), vec![
            13, 15, 14, 12,
             1,  3,  2,  0,
             5,  7,  6,  4,
             9, 11, 10,  8,
        ]));
    }

    #[test]
    fn submat() {
        let a = SpMat::from_dense_data((5, 6), 0..30);
        let b = a.submat(1..3, 2..5);
        assert_eq!(b, SpMat::from_dense_data((2, 3), vec![
             8,  9, 10,
            14, 15, 16
        ]));
    }

    #[test]
    fn transpose() {
        let a = SpMat::from_dense_data((3, 4), 0..12);
        let b = a.transpose();

        assert_eq!(b, SpMat::from_dense_data((4, 3), vec![
            0, 4,  8,
            1, 5,  9,
            2, 6, 10,
            3, 7, 11,
        ]));
    }

    #[test]
    fn extend_cols() {
        let mut a = SpMat::from_dense_data((4, 3), 0..12);
        let b = SpMat::from_dense_data((4, 2), 12..20);
        a.extend_cols(b);

        assert_eq!(a, SpMat::from_dense_data((4, 5), vec![
            0,  1,  2, 12, 13,
            3,  4,  5, 14, 15,
            6,  7,  8, 16, 17,
            9, 10, 11, 18, 19,
        ]));
    }

    // assert_eq! (rather than assert!(x == y)) prints both matrices on failure.
    #[test]
    fn row_perm() {
        let a = SpMat::from_dense_data((3, 4), 0..12);
        let p = PermOwned::new(vec![2, 0, 1]);
        let q = SpMat::from_row_perm(p.view());
        assert_eq!(q * &a, a.permute_rows(p.view()));
    }

    #[test]
    fn col_perm() {
        let a = SpMat::from_dense_data((3, 4), 0..12);
        let p = PermOwned::new(vec![2, 0, 1, 3]);
        let q = SpMat::from_col_perm(p.view());
        assert_eq!(&a * q, a.permute_cols(p.view()));
    }

    #[test]
    #[cfg(feature = "serde")]
    fn serialize() {
        let a = SpMat::from_dense_data((3, 4), (0..12).map(|x| x % 5));
        let ser = serde_json::to_string(&a).unwrap();
        let des = serde_json::from_str(&ser).unwrap();
        assert_eq!(a, des);
    }
}