sciforge_lib/maths/sparse/
csr.rs1use std::fmt;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// A sparse matrix of `f64` values stored in compressed sparse row (CSR) form.
///
/// Row `i` owns the stored entries at positions `row_ptr[i]..row_ptr[i + 1]`
/// of `col_idx` (column indices) and `values` (entry values). Matrices built by
/// [`SparseMatrix::from_triplets`] have strictly increasing column indices
/// within each row and never store the same position twice; `get` and
/// `to_dense` assume that invariant.
#[derive(Clone, Debug)]
pub struct SparseMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row offsets into `col_idx` / `values`; has length `rows + 1`.
    pub row_ptr: Vec<usize>,
    /// Column index of each stored entry.
    pub col_idx: Vec<usize>,
    /// Value of each stored entry.
    pub values: Vec<f64>,
}

impl SparseMatrix {
    /// Entries whose magnitude falls below this threshold are not stored.
    const DROP_TOL: f64 = 1e-30;

    /// Creates an empty `rows x cols` matrix with no stored entries.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            row_ptr: vec![0; rows + 1],
            col_idx: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Builds a `rows x cols` matrix from `(row, col, value)` triplets.
    ///
    /// `triplets` is sorted in place by `(row, col)`. Triplets that share a
    /// position are summed (in their input order), and any resulting entry
    /// whose magnitude is below `1e-30` is dropped, so e.g. `+x` and `-x` at
    /// the same position cancel out entirely.
    ///
    /// # Panics
    ///
    /// Panics if any triplet lies outside the `rows x cols` shape.
    pub fn from_triplets(rows: usize, cols: usize, triplets: &mut [(usize, usize, f64)]) -> Self {
        // Stable sort, so duplicates are accumulated in their original order.
        triplets.sort_by_key(|&(r, c, _)| (r, c));
        let mut row_ptr = vec![0; rows + 1];
        let mut col_idx = Vec::with_capacity(triplets.len());
        let mut values = Vec::with_capacity(triplets.len());

        let mut iter = triplets.iter().copied().peekable();
        while let Some((r, c, mut v)) = iter.next() {
            assert!(
                r < rows && c < cols,
                "triplet ({}, {}) is out of bounds for a {}x{} matrix",
                r,
                c,
                rows,
                cols
            );
            // Fold every following triplet at the same position into this one.
            while let Some((_, _, dv)) = iter.next_if(|&(nr, nc, _)| nr == r && nc == c) {
                v += dv;
            }
            // Drop (near-)zeros only after summing, so cancellations vanish.
            if v.abs() < Self::DROP_TOL {
                continue;
            }
            row_ptr[r + 1] += 1;
            col_idx.push(c);
            values.push(v);
        }
        // Turn per-row counts into cumulative offsets.
        for i in 0..rows {
            row_ptr[i + 1] += row_ptr[i];
        }
        Self {
            rows,
            cols,
            row_ptr,
            col_idx,
            values,
        }
    }

    /// Returns the `n x n` identity matrix.
    pub fn identity(n: usize) -> Self {
        let mut triplets: Vec<(usize, usize, f64)> = (0..n).map(|i| (i, i, 1.0)).collect();
        Self::from_triplets(n, n, &mut triplets)
    }

    /// Number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Index range of row `i`'s entries within `col_idx` and `values`.
    fn row_range(&self, i: usize) -> std::ops::Range<usize> {
        self.row_ptr[i]..self.row_ptr[i + 1]
    }

    /// Returns the value at `(row, col)`, or `0.0` if nothing is stored there.
    ///
    /// A `col` outside the matrix simply yields `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows`.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        let range = self.row_range(row);
        self.col_idx[range.clone()]
            .iter()
            .zip(&self.values[range])
            .find(|&(&c, _)| c == col)
            .map_or(0.0, |(_, &v)| v)
    }

    /// Computes the matrix-vector product `self * x`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != self.cols`.
    pub fn mul_vec(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(x.len(), self.cols);
        let mut result = vec![0.0; self.rows];
        for (i, ri) in result.iter_mut().enumerate() {
            for k in self.row_range(i) {
                *ri += self.values[k] * x[self.col_idx[k]];
            }
        }
        result
    }

    /// Returns the `cols x rows` transpose.
    pub fn transpose(&self) -> Self {
        let mut triplets: Vec<(usize, usize, f64)> = Vec::with_capacity(self.nnz());
        for i in 0..self.rows {
            for k in self.row_range(i) {
                triplets.push((self.col_idx[k], i, self.values[k]));
            }
        }
        Self::from_triplets(self.cols, self.rows, &mut triplets)
    }

    /// Returns the element-wise sum `self + other`.
    ///
    /// Overlapping entries are summed; entries that cancel are dropped.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    pub fn add(&self, other: &Self) -> Self {
        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);
        let mut triplets = Vec::with_capacity(self.nnz() + other.nnz());
        for i in 0..self.rows {
            for k in self.row_range(i) {
                triplets.push((i, self.col_idx[k], self.values[k]));
            }
            for k in other.row_range(i) {
                triplets.push((i, other.col_idx[k], other.values[k]));
            }
        }
        Self::from_triplets(self.rows, self.cols, &mut triplets)
    }

    /// Returns a copy with every stored value multiplied by `s`.
    ///
    /// The sparsity pattern is kept as-is (even when `s == 0.0`).
    pub fn scale(&self, s: f64) -> Self {
        Self {
            rows: self.rows,
            cols: self.cols,
            row_ptr: self.row_ptr.clone(),
            col_idx: self.col_idx.clone(),
            values: self.values.iter().map(|v| v * s).collect(),
        }
    }

    /// Computes the matrix product `self * other`.
    ///
    /// Each output row is accumulated from the rows of `other` selected by the
    /// stored entries of the corresponding row of `self`. The cost therefore
    /// scales with the number of multiply-adds rather than with
    /// `self.rows * other.cols`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    pub fn matmul(&self, other: &Self) -> Self {
        assert_eq!(self.cols, other.rows);
        let mut triplets = Vec::new();
        // Dense scratch row for the current output row, plus the list of
        // columns touched in it so the scratch can be reset cheaply.
        let mut acc = vec![0.0; other.cols];
        let mut touched = vec![false; other.cols];
        let mut row_cols: Vec<usize> = Vec::new();
        for i in 0..self.rows {
            for p in self.row_range(i) {
                let (k, a) = (self.col_idx[p], self.values[p]);
                for q in other.row_range(k) {
                    let j = other.col_idx[q];
                    if !touched[j] {
                        touched[j] = true;
                        row_cols.push(j);
                    }
                    acc[j] += a * other.values[q];
                }
            }
            for j in row_cols.drain(..) {
                triplets.push((i, j, acc[j]));
                acc[j] = 0.0;
                touched[j] = false;
            }
        }
        // from_triplets orders columns and drops near-zero sums.
        Self::from_triplets(self.rows, other.cols, &mut triplets)
    }

    /// Returns the main diagonal (length `min(rows, cols)`).
    pub fn diagonal(&self) -> Vec<f64> {
        let n = self.rows.min(self.cols);
        (0..n).map(|i| self.get(i, i)).collect()
    }

    /// Frobenius norm: the square root of the sum of squared stored values.
    pub fn frobenius_norm(&self) -> f64 {
        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Expands into a dense row-major `rows x cols` matrix.
    pub fn to_dense(&self) -> Vec<Vec<f64>> {
        let mut dense = vec![vec![0.0; self.cols]; self.rows];
        for (i, dense_row) in dense.iter_mut().enumerate() {
            for k in self.row_range(i) {
                dense_row[self.col_idx[k]] = self.values[k];
            }
        }
        dense
    }
}

164impl fmt::Display for SparseMatrix {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 write!(
167 f,
168 "SparseMatrix({}x{}, nnz={})",
169 self.rows,
170 self.cols,
171 self.nnz()
172 )
173 }
174}
175
176impl Add for &SparseMatrix {
177 type Output = SparseMatrix;
178 fn add(self, rhs: Self) -> SparseMatrix {
179 self.add(rhs)
180 }
181}
182
183impl Sub for &SparseMatrix {
184 type Output = SparseMatrix;
185 fn sub(self, rhs: Self) -> SparseMatrix {
186 self.add(&rhs.scale(-1.0))
187 }
188}
189
190impl Mul for &SparseMatrix {
191 type Output = SparseMatrix;
192 fn mul(self, rhs: Self) -> SparseMatrix {
193 self.matmul(rhs)
194 }
195}
196
197impl Mul<f64> for &SparseMatrix {
198 type Output = SparseMatrix;
199 fn mul(self, s: f64) -> SparseMatrix {
200 self.scale(s)
201 }
202}
203
204impl Neg for &SparseMatrix {
205 type Output = SparseMatrix;
206 fn neg(self) -> SparseMatrix {
207 self.scale(-1.0)
208 }
209}