1use crate::csr::CsrMatrix;
4use crate::error::{SparseError, SparseResult};
5use scirs2_core::numeric::{Float, NumAssign, SparseElement};
6use std::iter::Sum;
7
8#[allow(dead_code)]
17pub fn spsolve<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
18where
19 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
20{
21 let a_dense = a.to_dense();
25 gaussian_elimination(&a_dense, b)
26}
27
28#[allow(dead_code)]
30pub fn sparse_direct_solve<F>(
31 a: &CsrMatrix<F>,
32 b: &[F],
33 _symmetric: bool,
34 _positive_definite: bool,
35) -> SparseResult<Vec<F>>
36where
37 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
38{
39 if a.rows() != b.len() {
40 return Err(SparseError::DimensionMismatch {
41 expected: a.rows(),
42 found: b.len(),
43 });
44 }
45
46 if a.rows() != a.cols() {
47 return Err(SparseError::ValueError(format!(
48 "Matrix must be square, got {}x{}",
49 a.rows(),
50 a.cols()
51 )));
52 }
53
54 let a_dense = a.to_dense();
57 gaussian_elimination(&a_dense, b)
58}
59
60#[allow(dead_code)]
62pub fn sparse_lstsq<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
63where
64 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
65{
66 let at = a.transpose();
68 let ata = matmul(&at, a)?;
69 let mut atb = vec![F::sparse_zero(); at.rows()];
71 for (row, atb_val) in atb.iter_mut().enumerate().take(at.rows()) {
72 let row_range = at.row_range(row);
73 let row_indices = &at.indices[row_range.clone()];
74 let row_data = &at.data[row_range];
75
76 let mut sum = F::sparse_zero();
77 for (col_idx, &col) in row_indices.iter().enumerate() {
78 sum += row_data[col_idx] * b[col];
79 }
80 *atb_val = sum;
81 }
82 spsolve(&ata, &atb)
83}
84
85#[allow(dead_code)]
87pub fn norm<F>(a: &CsrMatrix<F>, ord: &str) -> SparseResult<F>
88where
89 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
90{
91 match ord {
92 "1" => {
93 let mut max_sum = F::sparse_zero();
95 for j in 0..a.cols() {
96 let mut col_sum = F::sparse_zero();
97 for i in 0..a.rows() {
98 let val = a.get(i, j);
99 if val != F::sparse_zero() {
100 col_sum += val.abs();
101 }
102 }
103 if col_sum > max_sum {
104 max_sum = col_sum;
105 }
106 }
107 Ok(max_sum)
108 }
109 "inf" => {
110 let mut max_sum = F::sparse_zero();
112 for i in 0..a.rows() {
113 let mut row_sum = F::sparse_zero();
114 for j in 0..a.cols() {
115 let val = a.get(i, j);
116 if val != F::sparse_zero() {
117 row_sum += val.abs();
118 }
119 }
120 if row_sum > max_sum {
121 max_sum = row_sum;
122 }
123 }
124 Ok(max_sum)
125 }
126 "fro" => {
127 let sum_squares: F = a.data.iter().map(|v| *v * *v).sum();
129 Ok(sum_squares.sqrt())
130 }
131 _ => Err(SparseError::ValueError(format!("Unknown norm: {ord}"))),
132 }
133}
134
135#[allow(dead_code)]
137pub fn matmul<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
138where
139 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
140{
141 let mut result_rows = Vec::new();
143 let mut result_cols = Vec::new();
144 let mut result_data = Vec::new();
145
146 for i in 0..a.rows() {
147 for j in 0..b.cols() {
148 let mut sum = F::sparse_zero();
149 for k in 0..a.cols() {
150 sum += a.get(i, k) * b.get(k, j);
151 }
152 if sum != F::sparse_zero() {
153 result_rows.push(i);
154 result_cols.push(j);
155 result_data.push(sum);
156 }
157 }
158 }
159
160 CsrMatrix::new(result_data, result_rows, result_cols, (a.rows(), b.cols()))
161}
162
163#[allow(dead_code)]
165pub fn add<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
166where
167 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
168{
169 if a.shape() != b.shape() {
170 return Err(SparseError::ShapeMismatch {
171 expected: a.shape(),
172 found: b.shape(),
173 });
174 }
175
176 let a_dense = a.to_dense();
178 let b_dense = b.to_dense();
179
180 let mut result_dense = vec![vec![F::sparse_zero(); a.cols()]; a.rows()];
181 for i in 0..a.rows() {
182 for j in 0..a.cols() {
183 result_dense[i][j] = a_dense[i][j] + b_dense[i][j];
184 }
185 }
186
187 let mut rows = Vec::new();
189 let mut cols = Vec::new();
190 let mut data = Vec::new();
191
192 for (i, row) in result_dense.iter().enumerate().take(a.rows()) {
193 for (j, &val) in row.iter().enumerate().take(a.cols()) {
194 if val != F::sparse_zero() {
195 rows.push(i);
196 cols.push(j);
197 data.push(val);
198 }
199 }
200 }
201
202 CsrMatrix::new(data, rows, cols, a.shape())
203}
204
205#[allow(dead_code)]
207pub fn multiply<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
208where
209 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
210{
211 if a.shape() != b.shape() {
212 return Err(SparseError::ShapeMismatch {
213 expected: a.shape(),
214 found: b.shape(),
215 });
216 }
217
218 let mut rows = Vec::new();
219 let mut cols = Vec::new();
220 let mut data = Vec::new();
221
222 for i in 0..a.rows() {
224 for j in 0..a.cols() {
225 let a_val = a.get(i, j);
226 let b_val = b.get(i, j);
227 if a_val != F::sparse_zero() && b_val != F::sparse_zero() {
228 rows.push(i);
229 cols.push(j);
230 data.push(a_val * b_val);
231 }
232 }
233 }
234
235 CsrMatrix::new(data, rows, cols, a.shape())
236}
237
238#[allow(dead_code)]
240pub fn diag_matrix<F>(diag: &[F], n: Option<usize>) -> SparseResult<CsrMatrix<F>>
241where
242 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
243{
244 let size = n.unwrap_or(diag.len());
245 if size < diag.len() {
246 return Err(SparseError::ValueError(
247 "Size must be at least as large as diagonal".to_string(),
248 ));
249 }
250
251 let mut rows = Vec::new();
252 let mut cols = Vec::new();
253 let mut data = Vec::new();
254
255 for (i, &val) in diag.iter().enumerate() {
256 if val != F::sparse_zero() {
257 rows.push(i);
258 cols.push(i);
259 data.push(val);
260 }
261 }
262
263 CsrMatrix::new(data, rows, cols, (size, size))
264}
265
266#[allow(dead_code)]
268pub fn eye<F>(n: usize) -> SparseResult<CsrMatrix<F>>
269where
270 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
271{
272 let diag = vec![F::sparse_one(); n];
273 diag_matrix(&diag, Some(n))
274}
275
276#[allow(dead_code)]
278pub fn inv<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
279where
280 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
281{
282 if a.rows() != a.cols() {
283 return Err(SparseError::ValueError(
284 "Matrix must be square for inverse".to_string(),
285 ));
286 }
287
288 let n = a.rows();
289
290 let mut inv_cols = Vec::new();
292
293 for j in 0..n {
294 let mut col_vec = vec![F::sparse_zero(); n];
296 col_vec[j] = F::sparse_one();
297 let x = spsolve(a, &col_vec)?;
298 inv_cols.push(x);
299 }
300
301 let mut rows = Vec::new();
303 let mut cols = Vec::new();
304 let mut data = Vec::new();
305
306 for (j, col) in inv_cols.iter().enumerate() {
307 for (i, &val) in col.iter().enumerate() {
308 if val.abs() > F::epsilon() {
309 rows.push(i);
310 cols.push(j);
311 data.push(val);
312 }
313 }
314 }
315
316 CsrMatrix::new(data, rows, cols, (n, n))
317}
318
319#[allow(dead_code)]
323pub fn matrix_power<F>(a: &CsrMatrix<F>, power: i32) -> SparseResult<CsrMatrix<F>>
324where
325 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
326{
327 if a.rows() != a.cols() {
328 return Err(SparseError::ValueError(
329 "Matrix must be square for power".to_string(),
330 ));
331 }
332
333 match power {
334 0 => eye(a.rows()),
335 1 => Ok(a.clone()),
336 p if p > 0 => {
337 let mut result = a.clone();
338 for _ in 1..p {
339 result = matmul(&result, a)?;
340 }
341 Ok(result)
342 }
343 p => {
344 let inv_a = inv(a)?;
346 matrix_power(&inv_a, -p)
347 }
348 }
349}
350
351#[allow(dead_code)]
354fn gaussian_elimination<F>(a: &[Vec<F>], b: &[F]) -> SparseResult<Vec<F>>
355where
356 F: Float + NumAssign + SparseElement,
357{
358 let n = a.len();
359 let mut aug = vec![vec![F::sparse_zero(); n + 1]; n];
360
361 for i in 0..n {
363 for j in 0..n {
364 aug[i][j] = a[i][j];
365 }
366 aug[i][n] = b[i];
367 }
368
369 for k in 0..n {
371 let mut max_idx = k;
373 for i in (k + 1)..n {
374 if aug[i][k].abs() > aug[max_idx][k].abs() {
375 max_idx = i;
376 }
377 }
378 aug.swap(k, max_idx);
379
380 if aug[k][k].abs() < F::epsilon() {
382 return Err(SparseError::SingularMatrix(
383 "Matrix is singular".to_string(),
384 ));
385 }
386
387 for i in (k + 1)..n {
389 let factor = aug[i][k] / aug[k][k];
390 for j in k..=n {
391 aug[i][j] = aug[i][j] - factor * aug[k][j];
392 }
393 }
394 }
395
396 let mut x = vec![F::sparse_zero(); n];
398 for i in (0..n).rev() {
399 x[i] = aug[i][n];
400 for j in (i + 1)..n {
401 x[i] = x[i] - aug[i][j] * x[j];
402 }
403 x[i] /= aug[i][i];
404 }
405
406 Ok(x)
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_matrix() {
        let identity = eye::<f64>(3).unwrap();
        assert_eq!(identity.shape(), (3, 3));
        for k in 0..3 {
            assert_eq!(identity.get(k, k), 1.0);
        }
        assert_eq!(identity.get(0, 1), 0.0);
    }

    #[test]
    fn test_diag_matrix() {
        let values = vec![2.0, 3.0, 4.0];
        let matrix = diag_matrix(&values, None).unwrap();
        assert_eq!(matrix.shape(), (3, 3));
        for (k, &expected) in values.iter().enumerate() {
            assert_eq!(matrix.get(k, k), expected);
        }
    }

    #[test]
    fn test_matrix_power() {
        let values = vec![2.0, 3.0];
        let base = diag_matrix(&values, None).unwrap();

        let squared = matrix_power(&base, 2).unwrap();
        assert_eq!(squared.get(0, 0), 4.0);
        assert_eq!(squared.get(1, 1), 9.0);

        let identity = matrix_power(&base, 0).unwrap();
        assert_eq!(identity.get(0, 0), 1.0);
        assert_eq!(identity.get(1, 1), 1.0);
    }
}