Skip to main content

sciforge_lib/maths/sparse/
csr.rs

1use std::fmt;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// A real-valued sparse matrix in compressed sparse row (CSR) layout.
///
/// The stored entries of row `i` occupy positions `row_ptr[i]..row_ptr[i + 1]` of
/// `col_idx` (their column indices) and `values` (their values), so `row_ptr` has
/// `rows + 1` elements and `row_ptr[rows]` equals the number of stored entries.
///
/// Matrices produced by the constructors and operations in this impl keep column indices
/// sorted within each row and never store the same `(row, col)` position twice (`scale`
/// preserves whatever structure its input has). Because the fields are public, a
/// hand-assembled matrix is not guaranteed to uphold those invariants.
#[derive(Clone, Debug)]
pub struct SparseMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row offsets into `col_idx` / `values`; `rows + 1` elements, starting at 0.
    pub row_ptr: Vec<usize>,
    /// Column index of each stored entry.
    pub col_idx: Vec<usize>,
    /// Value of each stored entry, parallel to `col_idx`.
    pub values: Vec<f64>,
}

impl SparseMatrix {
    /// Magnitude below which a value is treated as zero and left out of the matrix.
    const DROP_TOLERANCE: f64 = 1e-30;

    /// Creates a `rows` x `cols` matrix with no stored entries (all zeros).
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            row_ptr: vec![0; rows + 1],
            col_idx: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Builds a `rows` x `cols` matrix from `(row, col, value)` triplets (COO format).
    ///
    /// `triplets` is sorted in place by `(row, col)` with a stable sort, so triplets at the
    /// same position keep their relative input order. Triplets sharing a position are summed
    /// into a single entry. A resulting value whose magnitude is below `1e-30` is not stored,
    /// so entries that cancel out leave no explicit zero behind.
    ///
    /// # Panics
    ///
    /// Panics if any triplet has `row >= rows` or `col >= cols`.
    pub fn from_triplets(rows: usize, cols: usize, triplets: &mut [(usize, usize, f64)]) -> Self {
        // Stable sort: duplicates keep their input order, which fixes their summation order.
        triplets.sort_by_key(|&(r, c, _)| (r, c));

        let mut row_ptr = vec![0; rows + 1];
        let mut col_idx = Vec::with_capacity(triplets.len());
        let mut values = Vec::with_capacity(triplets.len());

        let mut entries = triplets.iter().peekable();
        while let Some(&(r, c, first)) = entries.next() {
            assert!(
                r < rows && c < cols,
                "triplet ({}, {}) is out of bounds for a {}x{} matrix",
                r,
                c,
                rows,
                cols
            );
            // Sorting made equal positions adjacent; fold them into one entry.
            let mut v = first;
            while let Some(&(_, _, dup)) = entries.next_if(|&&(nr, nc, _)| nr == r && nc == c) {
                v += dup;
            }
            if v.abs() < Self::DROP_TOLERANCE {
                continue;
            }
            // Count entries per row at index r + 1; converted to offsets below.
            row_ptr[r + 1] += 1;
            col_idx.push(c);
            values.push(v);
        }

        // Prefix-sum the per-row counts into cumulative offsets (row_ptr[0] stays 0).
        let mut running = 0;
        for p in row_ptr.iter_mut() {
            running += *p;
            *p = running;
        }

        Self {
            rows,
            cols,
            row_ptr,
            col_idx,
            values,
        }
    }

    /// Returns the `n` x `n` identity matrix.
    pub fn identity(n: usize) -> Self {
        Self {
            rows: n,
            cols: n,
            row_ptr: (0..=n).collect(),
            col_idx: (0..n).collect(),
            values: vec![1.0; n],
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Index range of row `i`'s entries within `col_idx` / `values`.
    fn row_range(&self, i: usize) -> std::ops::Range<usize> {
        self.row_ptr[i]..self.row_ptr[i + 1]
    }

    /// Returns the value stored at `(row, col)`, or `0.0` if no entry is stored there.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows`. An out-of-range `col` simply yields `0.0`.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        let range = self.row_range(row);
        let start = range.start;
        self.col_idx[range]
            .iter()
            .position(|&c| c == col)
            .map_or(0.0, |offset| self.values[start + offset])
    }

    /// Computes the matrix-vector product `self * x`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != self.cols`.
    pub fn mul_vec(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(
            x.len(),
            self.cols,
            "vector length must equal the number of columns"
        );
        let mut result = vec![0.0; self.rows];
        for (i, ri) in result.iter_mut().enumerate() {
            for k in self.row_range(i) {
                *ri += self.values[k] * x[self.col_idx[k]];
            }
        }
        result
    }

    /// Returns the transpose as a new `cols` x `rows` matrix.
    pub fn transpose(&self) -> Self {
        let mut triplets = Vec::with_capacity(self.nnz());
        for i in 0..self.rows {
            for k in self.row_range(i) {
                triplets.push((self.col_idx[k], i, self.values[k]));
            }
        }
        Self::from_triplets(self.cols, self.rows, &mut triplets)
    }

    /// Returns the element-wise sum `self + other`.
    ///
    /// Positions where the two operands cancel (magnitude below `1e-30`) are not stored.
    ///
    /// # Panics
    ///
    /// Panics if the two matrices have different shapes.
    pub fn add(&self, other: &Self) -> Self {
        assert_eq!(self.rows, other.rows, "row counts differ");
        assert_eq!(self.cols, other.cols, "column counts differ");
        let mut triplets = Vec::with_capacity(self.nnz() + other.nnz());
        for i in 0..self.rows {
            for k in self.row_range(i) {
                triplets.push((i, self.col_idx[k], self.values[k]));
            }
            for k in other.row_range(i) {
                triplets.push((i, other.col_idx[k], other.values[k]));
            }
        }
        // from_triplets merges positions present in both operands into one entry.
        Self::from_triplets(self.rows, self.cols, &mut triplets)
    }

    /// Returns a copy with every stored value multiplied by `s`.
    ///
    /// The sparsity pattern is kept as-is, even where a product becomes zero.
    pub fn scale(&self, s: f64) -> Self {
        Self {
            rows: self.rows,
            cols: self.cols,
            row_ptr: self.row_ptr.clone(),
            col_idx: self.col_idx.clone(),
            values: self.values.iter().map(|v| v * s).collect(),
        }
    }

    /// Returns the matrix product `self * other`.
    ///
    /// Computed row by row (Gustavson's algorithm). Each output row accumulates scaled rows
    /// of `other` in a dense scratch buffer, so the work is proportional to the number of
    /// multiply-adds plus the output size, not to `self.rows * other.cols`. Results whose
    /// magnitude is below `1e-30` are not stored.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    pub fn matmul(&self, other: &Self) -> Self {
        assert_eq!(self.cols, other.rows, "inner dimensions must agree");
        let mut triplets = Vec::new();
        // Scratch state for one output row: accumulated values, and which columns were hit.
        let mut acc = vec![0.0; other.cols];
        let mut hit = vec![false; other.cols];
        let mut touched: Vec<usize> = Vec::new();

        for i in 0..self.rows {
            for k in self.row_range(i) {
                let a = self.values[k];
                for m in other.row_range(self.col_idx[k]) {
                    let j = other.col_idx[m];
                    if !hit[j] {
                        hit[j] = true;
                        touched.push(j);
                    }
                    acc[j] += a * other.values[m];
                }
            }
            // Flush the row and reset only the slots that were used.
            for &j in &touched {
                triplets.push((i, j, acc[j]));
                acc[j] = 0.0;
                hit[j] = false;
            }
            touched.clear();
        }
        Self::from_triplets(self.rows, other.cols, &mut triplets)
    }

    /// Returns the main diagonal, of length `min(rows, cols)`.
    pub fn diagonal(&self) -> Vec<f64> {
        let n = self.rows.min(self.cols);
        (0..n).map(|i| self.get(i, i)).collect()
    }

    /// Frobenius norm: the square root of the sum of squared stored values.
    pub fn frobenius_norm(&self) -> f64 {
        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Expands into a dense matrix, as a `Vec` of `rows` rows each of length `cols`.
    pub fn to_dense(&self) -> Vec<Vec<f64>> {
        let mut dense = vec![vec![0.0; self.cols]; self.rows];
        for (i, dense_row) in dense.iter_mut().enumerate() {
            for k in self.row_range(i) {
                dense_row[self.col_idx[k]] = self.values[k];
            }
        }
        dense
    }
}
163
164impl fmt::Display for SparseMatrix {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        write!(
167            f,
168            "SparseMatrix({}x{}, nnz={})",
169            self.rows,
170            self.cols,
171            self.nnz()
172        )
173    }
174}
175
176impl Add for &SparseMatrix {
177    type Output = SparseMatrix;
178    fn add(self, rhs: Self) -> SparseMatrix {
179        self.add(rhs)
180    }
181}
182
183impl Sub for &SparseMatrix {
184    type Output = SparseMatrix;
185    fn sub(self, rhs: Self) -> SparseMatrix {
186        self.add(&rhs.scale(-1.0))
187    }
188}
189
190impl Mul for &SparseMatrix {
191    type Output = SparseMatrix;
192    fn mul(self, rhs: Self) -> SparseMatrix {
193        self.matmul(rhs)
194    }
195}
196
197impl Mul<f64> for &SparseMatrix {
198    type Output = SparseMatrix;
199    fn mul(self, s: f64) -> SparseMatrix {
200        self.scale(s)
201    }
202}
203
204impl Neg for &SparseMatrix {
205    type Output = SparseMatrix;
206    fn neg(self) -> SparseMatrix {
207        self.scale(-1.0)
208    }
209}