Skip to main content

yui_matrix/sparse/
sp_mat.rs

1use std::ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, MulAssign, Range};
2use std::iter::zip;
3use std::fmt::{Display, Debug};
4use delegate::delegate;
5use nalgebra_sparse::na::{Scalar, ClosedAddAssign, ClosedSubAssign, ClosedMulAssign};
6use nalgebra_sparse::{CscMatrix, CooMatrix};
7use num_traits::{Zero, One, ToPrimitive};
8use auto_impl_ops::auto_ops;
9use sprs::PermView;
10use yui_core::{Ring, RingOps};
11use crate::dense::*;
12use super::sp_vec::SpVec;
13use super::triang::TriangularType;
14
/// A sparse matrix stored in compressed sparse column (CSC) form.
///
/// A thin wrapper around `nalgebra_sparse::CscMatrix<R>`. Trait bounds on `R`
/// live on the individual `impl` blocks rather than on the type itself.
#[derive(Clone, PartialEq, Eq)]
pub struct SpMat<R> { 
    // The wrapped CSC matrix; all storage and arithmetic are delegated to it.
    inner: CscMatrix<R>
}
19
20impl<R> MatTrait for SpMat<R> {
21    fn shape(&self) -> (usize, usize) {
22        (self.inner.nrows(), self.inner.ncols())
23    }
24}
25
26impl<R> SpMat<R> { 
27    pub(crate) fn inner(&self) -> &CscMatrix<R> { 
28        &self.inner
29    }
30
31    pub(crate) fn into_inner(self) -> CscMatrix<R> { 
32        self.inner
33    }
34
35    pub fn data(&self) -> (&[usize], &[usize], &[R]) { 
36        self.inner.csc_data()
37    }
38
39    pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<R>) { 
40        self.inner.disassemble()
41    }
42
43    pub fn zero(shape: (usize, usize)) -> Self {
44        let csc = CscMatrix::zeros(shape.0, shape.1);
45        Self::from(csc)
46    }
47
48    pub fn is_zero(&self) -> bool
49    where R: Zero {
50        self.inner.values().iter().all(|a| a.is_zero())
51    }
52
53    pub fn id(n: usize) -> Self
54    where R: Scalar + One { 
55        let csc = CscMatrix::identity(n);
56        Self::from(csc)
57    }
58
59    pub fn is_id(&self) -> bool
60    where R: Scalar + One + Zero {
61        self.is_square() && self.iter().all(|(i, j, a)| 
62            (i == j && a.is_one()) || (i != j && a.is_zero())
63        )
64    }
65
66    pub fn is_triang(&self, t: TriangularType) -> bool
67    where R: Zero {
68        if self.nrows() != self.ncols() { 
69            return false
70        }
71
72        if t.is_upper() { 
73            self.iter_nz().all(|(i, j, _)| i <= j )
74        } else { 
75            self.iter_nz().all(|(i, j, _)| i >= j )
76        }
77    }
78    
79    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &R)> { 
80        self.inner.triplet_iter()
81    }
82
83    pub fn iter_nz(&self) -> impl Iterator<Item = (usize, usize, &R)>
84    where R: Zero { 
85        self.iter().filter(|e| !e.2.is_zero())
86    }
87
88    pub fn into_dense(self) -> Mat<R>
89    where R: Scalar + Zero + ClosedAddAssign { 
90        self.into()
91    }
92
93    pub fn nnz(&self) -> usize { 
94        self.inner.nnz()
95    }
96
97    pub fn density(&self) -> f64 { 
98        let (m, n) = self.shape();
99        if m == 0 || n == 0 { 
100            return 0.0
101        }
102
103        let nnz = self.nnz().to_f64().unwrap();
104        let total = (m * n).to_f64().unwrap();
105
106        nnz / total
107    }
108
109    pub fn redundancy(&self) -> f64
110    where R: Zero { 
111        let nnz = self.nnz().to_f64().unwrap();
112        let red = self.iter().filter(|(_, _, a)| a.is_zero()).count().to_f64().unwrap();
113        red / nnz
114    }
115
116    pub fn mean_weight(&self) -> f64
117    where R: Ring, for<'x> &'x R: RingOps<R> { 
118        let nnz = self.nnz().to_f64().unwrap();
119        let w = self.iter().map(|(_, _, a)| a.c_weight()).sum::<f64>(); 
120        w / nnz
121    }
122}
123
124impl<R> SpMat<R> 
125where R: Scalar + Clone + Zero + ClosedAddAssign { 
126    pub fn from_entries<T>(shape: (usize, usize), entries: T) -> Self
127    where T: IntoIterator<Item = (usize, usize, R)> {
128        let mut coo = CooMatrix::new(shape.0, shape.1);
129        for (i, j, a) in entries { 
130            if a.is_zero() { 
131                continue;
132            }
133            coo.push(i, j, a)
134        }
135        let csc = CscMatrix::from(&coo);
136        Self::from(csc)
137    }
138
139    pub fn from_col_vecs<I>(nrows: usize, vecs: I) -> Self 
140    where I: IntoIterator<Item = SpVec<R>> { 
141        let mut col_offsets = vec![0];
142        let mut row_indices = vec![];
143        let mut values = vec![];
144
145        for v in vecs.into_iter() { 
146            assert_eq!(nrows, v.dim());
147            let (_, mut v_rows, mut v_values) = v.into_inner().disassemble();
148
149            row_indices.append(&mut v_rows);
150            values.append(&mut v_values);
151            col_offsets.push(row_indices.len());
152        }
153
154        let ncols = col_offsets.len() - 1;
155        let csc = CscMatrix::try_from_csc_data(nrows, ncols, col_offsets, row_indices, values).unwrap();
156        Self::from(csc)
157    }
158
159    pub fn from_dense_data<I>(shape: (usize, usize), data: I) -> Self
160    where I: IntoIterator<Item = R> { 
161        let n = shape.1;
162        Self::from_entries(
163            shape, 
164            data.into_iter().enumerate().map(|(k, a)| { 
165                let (i, j) = (k / n, k % n);
166                (i, j, a)
167            })
168        )
169    }
170
171    pub fn col_vec(&self, j: usize) -> SpVec<R>
172    where R: Scalar + Zero + ClosedAddAssign { 
173        let col = self.inner.col(j);
174        let iter = Iterator::zip(
175            col.row_indices().iter().cloned(), 
176            col.values().iter().cloned()
177        );
178        SpVec::from_entries(self.nrows(), iter)
179    }
180
181    pub fn transpose(&self) -> Self { 
182        self.inner.transpose().into()
183    }
184
185    pub fn extract<F>(&self, shape: (usize, usize), f: F) -> SpMat<R>
186    where F: Fn(usize, usize) -> Option<(usize, usize)> { 
187        SpMat::from_entries(shape, self.iter().filter_map(|(i, j, a)|
188            f(i, j).map(|(i, j)| (i, j, a.clone()))
189        ))
190    }
191
192    pub fn permute(&self, p: PermView, q: PermView) -> SpMat<R> { 
193        self.extract(self.shape(), |i, j| Some((p.at(i), q.at(j))))
194    }
195
196    pub fn permute_rows(&self, p: PermView) -> SpMat<R> { 
197        let id = PermView::identity(self.ncols());
198        self.permute(p, id)
199    }
200    
201    pub fn permute_cols(&self, q: PermView) -> SpMat<R> { 
202        let id = PermView::identity(self.nrows());
203        self.permute(id, q)
204    }
205
206    pub fn submat(&self, rows: Range<usize>, cols: Range<usize>) -> SpMat<R> { 
207        let (i0, i1) = (rows.start, rows.end);
208        let (j0, j1) = (cols.start, cols.end);
209
210        assert!(i0 <= i1 && i1 <= self.nrows());
211        assert!(j0 <= j1 && j1 <= self.ncols());
212
213        let shape = (i1 - i0, j1 - j0);
214        self.extract(shape, |i, j|
215            (rows.contains(&i) && cols.contains(&j)).then( ||
216                (i - i0, j - j0)
217            )
218        )
219    }
220
221    pub fn submat_rows(&self, rows: Range<usize>) -> SpMat<R> { 
222        let n = self.ncols();
223        self.submat(rows, 0 .. n)
224    }
225
226    pub fn submat_cols(&self, cols: Range<usize>) -> SpMat<R> { 
227        let m = self.nrows();
228        self.submat(0 .. m, cols)
229    }
230
231    pub fn divide4(&self, point: (usize, usize)) -> [SpMat<R>; 4] { 
232        let (m, n) = self.shape();
233        let (k, l) = point;
234        assert!(k <= m);
235        assert!(l <= n);
236
237        let mut a = CooMatrix::new(k, l);
238        let mut b = CooMatrix::new(k, n - l);
239        let mut c = CooMatrix::new(m - k, l);
240        let mut d = CooMatrix::new(m - k, n - l);
241        
242        for (i, j, r) in self.iter() { 
243            if r.is_zero() { continue }
244            let r = r.clone();
245            match ((0..k).contains(&i), (0..l).contains(&j)) { 
246                (true , true ) => a.push(i, j, r),
247                (true , false) => b.push(i, j - l, r),
248                (false, true ) => c.push(i - k, j, r),
249                (false, false) => d.push(i - k, j - l, r),
250            }
251        }
252        
253        [a, b, c, d].map(|x| 
254            CscMatrix::from(&x).into()
255        )
256    }
257
258    pub fn combine_blocks(blocks: [&SpMat<R>; 4]) -> SpMat<R> {
259        let [a, b, c, d] = blocks;
260
261        assert_eq!(a.nrows(), b.nrows());
262        assert_eq!(c.nrows(), d.nrows());
263        assert_eq!(a.ncols(), c.ncols());
264        assert_eq!(b.ncols(), d.ncols());
265
266        let (m, n) = (a.nrows() + c.nrows(), a.ncols() + b.ncols());
267        let (k, l) = a.shape();
268
269        let entries = zip(
270            [a, b, c, d], 
271            [(0,0), (0,l), (k,0), (k,l)]
272        ).flat_map(|(x, (di, dj))| 
273            x.iter().map(move |(i, j, r)|
274                (i + di, j + dj, r.clone())
275            )
276        );
277
278        Self::from_entries((m, n), entries)
279    }
280
281    pub fn concat(&self, b: &Self) -> Self { 
282        let zero = |m, n| SpMat::<R>::zero((m, n));
283        Self::combine_blocks([
284            self, 
285            b, 
286            &zero(0, self.ncols()), 
287            &zero(0, b.ncols())
288        ])
289    }
290
291    pub fn stack(&self, b: &Self) -> Self { 
292        let zero = |m, n| SpMat::<R>::zero((m, n));
293        Self::combine_blocks([
294            self, 
295            &zero(self.nrows(), 0), 
296            b, 
297            &zero(b.nrows(), 0)
298        ])
299    }
300
301    pub fn extend_cols(&mut self, b: Self) { 
302        assert_eq!(self.nrows(), b.nrows());
303
304        if b.ncols() == 0 { 
305            return
306        }
307
308        let shape = (self.nrows(), self.ncols() + b.ncols());
309        let l = std::mem::replace(&mut self.inner, CscMatrix::zeros(0, 0));
310        let r = b.inner;
311
312        let (mut col_offsets, mut row_indices, mut values) = l.disassemble();
313        let (c, mut r, mut v) = r.disassemble();
314        
315        let offset = col_offsets.pop().unwrap(); // pop last element.
316        col_offsets.extend(c.into_iter().map(|i| offset + i));
317        row_indices.append(&mut r);
318        values.append(&mut v);
319
320        self.inner = CscMatrix::try_from_csc_data(
321            shape.0, shape.1, 
322            col_offsets, 
323            row_indices, 
324            values
325        ).unwrap();
326    }
327
328    // row_perm(p) * a == a.permute_rows(p)
329    pub fn from_row_perm(p: PermView) -> Self
330    where R: One {
331        let n = p.dim();
332        Self::from_entries((n, n), (0..n).map(|i|
333            (p.at(i), i, R::one())
334        ))
335    }
336
337    // a * col_perm(p) == a.permute_cols(p)
338    pub fn from_col_perm(p: PermView) -> Self
339    where R: One {
340        let n = p.dim();
341        Self::from_entries((n, n), (0..n).map(|i|
342            (i, p.at(i), R::one())
343        ))
344    }
345}
346
347impl<R> From<CscMatrix<R>> for SpMat<R> {
348    fn from(inner: CscMatrix<R>) -> Self {
349        Self { inner }
350    }
351}
352
353impl<R> From<Mat<R>> for SpMat<R>
354where R: Scalar + Zero {
355    fn from(value: Mat<R>) -> Self {
356        let csc = CscMatrix::from(value.inner());
357        Self::from(csc)
358    }
359}
360
361impl<R> Default for SpMat<R> {
362    fn default() -> Self {
363        Self::zero((0, 0))
364    }
365}
366
367impl<R> Neg for SpMat<R>
368where R: Scalar + Neg<Output = R> {
369    type Output = Self;
370    fn neg(self) -> Self::Output {
371        Self::from(-self.inner)
372    }
373}
374
375impl<R> Neg for &SpMat<R>
376where R: Scalar + Neg<Output = R> {
377    type Output = SpMat<R>;
378    fn neg(self) -> Self::Output {
379        SpMat::from(-&self.inner)
380    }
381}
382
383// see: nalgebra_sparse::ops::impl_std_ops.
384macro_rules! impl_binop {
385    ($trait:ident, $method:ident) => {
386        #[auto_ops]
387        impl<'a, 'b, R> $trait<&'b SpMat<R>> for &'a SpMat<R>
388        where R: Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output = R> {
389            type Output = SpMat<R>;
390            fn $method(self, rhs: &'b SpMat<R>) -> Self::Output {
391                let res = (&self.inner).$method(&rhs.inner);
392                SpMat::from(res)
393            }
394        }
395    };
396}
397
398impl_binop!(Add, add);
399impl_binop!(Sub, sub);
400impl_binop!(Mul, mul);
401
402impl<R> Display for SpMat<R>
403where R: Display + Debug {
404    delegate! { to self.inner { 
405        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
406    }}
407}
408
409impl<R> Debug for SpMat<R>
410where R: Display + Debug {
411    delegate! { to self.inner { 
412        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
413    }}
414}
415
416#[cfg(feature = "serde")]
417impl<R> serde::Serialize for SpMat<R>
418where R: Clone + serde::Serialize {
419    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
420    where S: serde::Serializer {
421        self.inner.serialize(serializer)
422    }
423}
424
425#[cfg(feature = "serde")]
426impl<'de, R> serde::Deserialize<'de> for SpMat<R>
427where R: Clone + serde::Deserialize<'de> {
428    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
429    where D: serde::Deserializer<'de> {
430        let inner = CscMatrix::deserialize(deserializer)?;
431        let res = Self::from(inner);
432        Ok(res)
433    }
434}
435
436#[cfg(test)]
437impl<R> SpMat<R>
438where R: Scalar + Zero + One + ClosedAddAssign { 
439    pub fn rand(shape: (usize, usize), density: f64) -> Self {
440        use cartesian::cartesian;
441        use rand::Rng;
442    
443        let (m, n) = shape;
444        let range = cartesian!(0..m, 0..n);
445        let mut rng = rand::rng();
446    
447        Self::from_entries(shape, range.filter_map(|(i, j)|
448            if rng.random::<f64>() < density { 
449                Some((i, j, R::one()))
450            } else { 
451                None
452            }
453        ))
454    }
455}
456
457#[cfg(test)]
458pub(super) mod tests { 
459    use itertools::Itertools;
460    use sprs::PermOwned;
461    use yui_core::num::Ratio;
462
463    use super::*;
464
465    #[test]
466    fn init() { 
467        let a = SpMat::from_entries((2, 2), [
468            (0, 0, 1),
469            (0, 1, 2),
470            (1, 0, 3),
471            (1, 1, 4)
472        ]);
473        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vec![1, 3, 2, 4]));
474    }
475
476    #[test]
477    fn init_ratio() { 
478        type R = Ratio<i64>;
479        let vals = (0..4).map(|i| R::new(i + 1, 5)).collect_vec();
480        let a = SpMat::from_entries((2, 2), [
481            (0, 0, vals[0].clone()),
482            (0, 1, vals[2].clone()),
483            (1, 0, vals[1].clone()),
484            (1, 1, vals[3].clone())
485        ]);
486        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vals));
487    }
488
489    #[test]
490    fn from_grid() { 
491        let a = SpMat::from_dense_data((2, 2), [1,2,3,4]);
492        assert_eq!(a.disassemble(), (vec![0, 2, 4], vec![0, 1, 0, 1], vec![1, 3, 2, 4]));
493    }
494
495    #[test]
496    fn to_dense() { 
497        let a = SpMat::from_entries((2, 2), [
498            (0, 0, 1),
499            (0, 1, 2),
500            (1, 0, 3),
501            (1, 1, 4)
502        ]);
503        assert_eq!(a.into_dense(), Mat::from_data((2, 2), [1,2,3,4]));
504    }
505
506    #[test]
507    fn permute() { 
508        let p = PermOwned::new(vec![1,2,3,0]);
509        let q = PermOwned::new(vec![3,0,2,1]);
510        let a = SpMat::from_dense_data((4,4), 0..16);
511        let b = a.permute(p.view(), q.view());
512        assert_eq!(b, SpMat::from_dense_data((4,4), vec![
513            13, 15, 14, 12,
514             1,  3,  2,  0,
515             5,  7,  6,  4,
516             9, 11, 10,  8,
517        ]));
518    }
519
520    #[test]
521    fn submat() { 
522        let a = SpMat::from_dense_data((5, 6), 0..30);
523        let b = a.submat(1..3, 2..5);
524        assert_eq!(b, SpMat::from_dense_data((2,3), vec![
525             8,  9, 10,
526            14, 15, 16
527        ]));
528    }
529
530    #[test]
531    fn transpose() { 
532        let a = SpMat::from_dense_data((3,4), 0..12);
533        let b = a.transpose();
534
535        assert_eq!(b, SpMat::from_dense_data((4,3), vec![
536            0, 4, 8, 
537            1, 5, 9, 
538            2, 6, 10, 
539            3, 7, 11, 
540        ]));
541    }
542
543    #[test]
544    fn extend_cols() {
545        let mut a = SpMat::from_dense_data((4, 3), 0..12);
546        let b = SpMat::from_dense_data((4, 2), 12..20);
547        a.extend_cols(b);
548
549        assert_eq!(a, SpMat::from_dense_data((4,5), vec![
550            0,  1,  2, 12, 13,
551            3,  4,  5, 14, 15,
552            6,  7,  8, 16, 17,
553            9, 10, 11, 18, 19,
554        ]));
555    }
556
557    #[test]
558    fn row_perm() {
559        let a = SpMat::from_dense_data((3, 4), 0..12);
560        let p = PermOwned::new(vec![2,0,1]);
561        let q = SpMat::from_row_perm(p.view());
562        assert!(q * &a == a.permute_rows(p.view()))
563    }
564
565    #[test]
566    fn col_perm() {
567        let a = SpMat::from_dense_data((3, 4), 0..12);
568        let p = PermOwned::new(vec![2,0,1,3]);
569        let q = SpMat::from_col_perm(p.view());
570        assert!(&a * q == a.permute_cols(p.view()))
571    }
572
573    #[test]
574    #[cfg(feature = "serde")]
575    fn serialize() { 
576        let a = SpMat::from_dense_data((3, 4), (0..12).map(|x| x % 5));
577        let ser = serde_json::to_string(&a).unwrap();
578        let des = serde_json::from_str(&ser).unwrap();
579        assert_eq!(a, des);
580    }
581}