scirs2_sparse/linalg/
solvers.rs

1//! Direct solvers and basic operations for sparse matrices
2
3use crate::csr::CsrMatrix;
4use crate::error::{SparseError, SparseResult};
5use scirs2_core::numeric::{Float, NumAssign, SparseElement};
6use std::iter::Sum;
7
// These functions originally lived in linalg.rs. For now they are
// implemented here as dense-fallback stubs; a full migration should
// move the optimized sparse implementations from linalg.rs into this
// module.
14
15/// Solve a sparse linear system Ax = b
16#[allow(dead_code)]
17pub fn spsolve<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
18where
19    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
20{
21    // This implementation should be moved from linalg.rs
22    // For now, I'll forward to sparse_direct_solve
23    // For now, use a simple Gaussian elimination approach
24    let a_dense = a.to_dense();
25    gaussian_elimination(&a_dense, b)
26}
27
28/// Solve a sparse linear system using direct methods
29#[allow(dead_code)]
30pub fn sparse_direct_solve<F>(
31    a: &CsrMatrix<F>,
32    b: &[F],
33    _symmetric: bool,
34    _positive_definite: bool,
35) -> SparseResult<Vec<F>>
36where
37    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
38{
39    if a.rows() != b.len() {
40        return Err(SparseError::DimensionMismatch {
41            expected: a.rows(),
42            found: b.len(),
43        });
44    }
45
46    if a.rows() != a.cols() {
47        return Err(SparseError::ValueError(format!(
48            "Matrix must be square, got {}x{}",
49            a.rows(),
50            a.cols()
51        )));
52    }
53
54    // For this stub implementation, we'll use Gaussian elimination
55    // The real implementation should use optimized sparse solvers
56    let a_dense = a.to_dense();
57    gaussian_elimination(&a_dense, b)
58}
59
60/// Solve a least squares problem
61#[allow(dead_code)]
62pub fn sparse_lstsq<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
63where
64    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
65{
66    // For now, solve normal equations: A^T * A * x = A^T * b
67    let at = a.transpose();
68    let ata = matmul(&at, a)?;
69    // Compute A^T * b
70    let mut atb = vec![F::sparse_zero(); at.rows()];
71    for (row, atb_val) in atb.iter_mut().enumerate().take(at.rows()) {
72        let row_range = at.row_range(row);
73        let row_indices = &at.indices[row_range.clone()];
74        let row_data = &at.data[row_range];
75
76        let mut sum = F::sparse_zero();
77        for (col_idx, &col) in row_indices.iter().enumerate() {
78            sum += row_data[col_idx] * b[col];
79        }
80        *atb_val = sum;
81    }
82    spsolve(&ata, &atb)
83}
84
85/// Compute matrix norm
86#[allow(dead_code)]
87pub fn norm<F>(a: &CsrMatrix<F>, ord: &str) -> SparseResult<F>
88where
89    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
90{
91    match ord {
92        "1" => {
93            // 1-norm: maximum column sum
94            let mut max_sum = F::sparse_zero();
95            for j in 0..a.cols() {
96                let mut col_sum = F::sparse_zero();
97                for i in 0..a.rows() {
98                    let val = a.get(i, j);
99                    if val != F::sparse_zero() {
100                        col_sum += val.abs();
101                    }
102                }
103                if col_sum > max_sum {
104                    max_sum = col_sum;
105                }
106            }
107            Ok(max_sum)
108        }
109        "inf" => {
110            // Infinity norm: maximum row sum
111            let mut max_sum = F::sparse_zero();
112            for i in 0..a.rows() {
113                let mut row_sum = F::sparse_zero();
114                for j in 0..a.cols() {
115                    let val = a.get(i, j);
116                    if val != F::sparse_zero() {
117                        row_sum += val.abs();
118                    }
119                }
120                if row_sum > max_sum {
121                    max_sum = row_sum;
122                }
123            }
124            Ok(max_sum)
125        }
126        "fro" => {
127            // Frobenius norm: sqrt(sum of squares)
128            let sum_squares: F = a.data.iter().map(|v| *v * *v).sum();
129            Ok(sum_squares.sqrt())
130        }
131        _ => Err(SparseError::ValueError(format!("Unknown norm: {ord}"))),
132    }
133}
134
135/// Matrix multiplication
136#[allow(dead_code)]
137pub fn matmul<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
138where
139    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
140{
141    // Matrix multiplication - use a simple implementation
142    let mut result_rows = Vec::new();
143    let mut result_cols = Vec::new();
144    let mut result_data = Vec::new();
145
146    for i in 0..a.rows() {
147        for j in 0..b.cols() {
148            let mut sum = F::sparse_zero();
149            for k in 0..a.cols() {
150                sum += a.get(i, k) * b.get(k, j);
151            }
152            if sum != F::sparse_zero() {
153                result_rows.push(i);
154                result_cols.push(j);
155                result_data.push(sum);
156            }
157        }
158    }
159
160    CsrMatrix::new(result_data, result_rows, result_cols, (a.rows(), b.cols()))
161}
162
163/// Matrix addition
164#[allow(dead_code)]
165pub fn add<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
166where
167    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
168{
169    if a.shape() != b.shape() {
170        return Err(SparseError::ShapeMismatch {
171            expected: a.shape(),
172            found: b.shape(),
173        });
174    }
175
176    // Simple implementation using dense matrices
177    let a_dense = a.to_dense();
178    let b_dense = b.to_dense();
179
180    let mut result_dense = vec![vec![F::sparse_zero(); a.cols()]; a.rows()];
181    for i in 0..a.rows() {
182        for j in 0..a.cols() {
183            result_dense[i][j] = a_dense[i][j] + b_dense[i][j];
184        }
185    }
186
187    // Convert back to CSR
188    let mut rows = Vec::new();
189    let mut cols = Vec::new();
190    let mut data = Vec::new();
191
192    for (i, row) in result_dense.iter().enumerate().take(a.rows()) {
193        for (j, &val) in row.iter().enumerate().take(a.cols()) {
194            if val != F::sparse_zero() {
195                rows.push(i);
196                cols.push(j);
197                data.push(val);
198            }
199        }
200    }
201
202    CsrMatrix::new(data, rows, cols, a.shape())
203}
204
205/// Element-wise multiplication (Hadamard product)
206#[allow(dead_code)]
207pub fn multiply<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
208where
209    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
210{
211    if a.shape() != b.shape() {
212        return Err(SparseError::ShapeMismatch {
213            expected: a.shape(),
214            found: b.shape(),
215        });
216    }
217
218    let mut rows = Vec::new();
219    let mut cols = Vec::new();
220    let mut data = Vec::new();
221
222    // Only multiply where both matrices have non-zero entries
223    for i in 0..a.rows() {
224        for j in 0..a.cols() {
225            let a_val = a.get(i, j);
226            let b_val = b.get(i, j);
227            if a_val != F::sparse_zero() && b_val != F::sparse_zero() {
228                rows.push(i);
229                cols.push(j);
230                data.push(a_val * b_val);
231            }
232        }
233    }
234
235    CsrMatrix::new(data, rows, cols, a.shape())
236}
237
238/// Create a diagonal matrix
239#[allow(dead_code)]
240pub fn diag_matrix<F>(diag: &[F], n: Option<usize>) -> SparseResult<CsrMatrix<F>>
241where
242    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
243{
244    let size = n.unwrap_or(diag.len());
245    if size < diag.len() {
246        return Err(SparseError::ValueError(
247            "Size must be at least as large as diagonal".to_string(),
248        ));
249    }
250
251    let mut rows = Vec::new();
252    let mut cols = Vec::new();
253    let mut data = Vec::new();
254
255    for (i, &val) in diag.iter().enumerate() {
256        if val != F::sparse_zero() {
257            rows.push(i);
258            cols.push(i);
259            data.push(val);
260        }
261    }
262
263    CsrMatrix::new(data, rows, cols, (size, size))
264}
265
266/// Create an identity matrix
267#[allow(dead_code)]
268pub fn eye<F>(n: usize) -> SparseResult<CsrMatrix<F>>
269where
270    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
271{
272    let diag = vec![F::sparse_one(); n];
273    diag_matrix(&diag, Some(n))
274}
275
276/// Matrix inverse
277#[allow(dead_code)]
278pub fn inv<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
279where
280    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
281{
282    if a.rows() != a.cols() {
283        return Err(SparseError::ValueError(
284            "Matrix must be square for inverse".to_string(),
285        ));
286    }
287
288    let n = a.rows();
289
290    // Solve A * X = I for X
291    let mut inv_cols = Vec::new();
292
293    for j in 0..n {
294        // Get column j from identity matrix
295        let mut col_vec = vec![F::sparse_zero(); n];
296        col_vec[j] = F::sparse_one();
297        let x = spsolve(a, &col_vec)?;
298        inv_cols.push(x);
299    }
300
301    // Construct the inverse matrix from columns
302    let mut rows = Vec::new();
303    let mut cols = Vec::new();
304    let mut data = Vec::new();
305
306    for (j, col) in inv_cols.iter().enumerate() {
307        for (i, &val) in col.iter().enumerate() {
308            if val.abs() > F::epsilon() {
309                rows.push(i);
310                cols.push(j);
311                data.push(val);
312            }
313        }
314    }
315
316    CsrMatrix::new(data, rows, cols, (n, n))
317}
318
319// Matrix exponential functionality is now available in linalg/expm.rs module
320
321/// Matrix power
322#[allow(dead_code)]
323pub fn matrix_power<F>(a: &CsrMatrix<F>, power: i32) -> SparseResult<CsrMatrix<F>>
324where
325    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
326{
327    if a.rows() != a.cols() {
328        return Err(SparseError::ValueError(
329            "Matrix must be square for power".to_string(),
330        ));
331    }
332
333    match power {
334        0 => eye(a.rows()),
335        1 => Ok(a.clone()),
336        p if p > 0 => {
337            let mut result = a.clone();
338            for _ in 1..p {
339                result = matmul(&result, a)?;
340            }
341            Ok(result)
342        }
343        p => {
344            // Negative power: compute inverse and then positive power
345            let inv_a = inv(a)?;
346            matrix_power(&inv_a, -p)
347        }
348    }
349}
350
351// Helper functions
352
353#[allow(dead_code)]
354fn gaussian_elimination<F>(a: &[Vec<F>], b: &[F]) -> SparseResult<Vec<F>>
355where
356    F: Float + NumAssign + SparseElement,
357{
358    let n = a.len();
359    let mut aug = vec![vec![F::sparse_zero(); n + 1]; n];
360
361    // Create augmented matrix
362    for i in 0..n {
363        for j in 0..n {
364            aug[i][j] = a[i][j];
365        }
366        aug[i][n] = b[i];
367    }
368
369    // Forward elimination
370    for k in 0..n {
371        // Find pivot
372        let mut max_idx = k;
373        for i in (k + 1)..n {
374            if aug[i][k].abs() > aug[max_idx][k].abs() {
375                max_idx = i;
376            }
377        }
378        aug.swap(k, max_idx);
379
380        // Check for zero pivot
381        if aug[k][k].abs() < F::epsilon() {
382            return Err(SparseError::SingularMatrix(
383                "Matrix is singular".to_string(),
384            ));
385        }
386
387        // Eliminate column
388        for i in (k + 1)..n {
389            let factor = aug[i][k] / aug[k][k];
390            for j in k..=n {
391                aug[i][j] = aug[i][j] - factor * aug[k][j];
392            }
393        }
394    }
395
396    // Back substitution
397    let mut x = vec![F::sparse_zero(); n];
398    for i in (0..n).rev() {
399        x[i] = aug[i][n];
400        for j in (i + 1)..n {
401            x[i] = x[i] - aug[i][j] * x[j];
402        }
403        x[i] /= aug[i][i];
404    }
405
406    Ok(x)
407}
408
409// Helper functions for matrix exponential have been moved to linalg/expm.rs
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_matrix() {
        let identity = eye::<f64>(3).unwrap();
        assert_eq!(identity.shape(), (3, 3));
        // Ones on the diagonal, zero off it.
        for k in 0..3 {
            assert_eq!(identity.get(k, k), 1.0);
        }
        assert_eq!(identity.get(0, 1), 0.0);
    }

    #[test]
    fn test_diag_matrix() {
        let values = vec![2.0, 3.0, 4.0];
        let matrix = diag_matrix(&values, None).unwrap();
        assert_eq!(matrix.shape(), (3, 3));
        for (k, &expected) in values.iter().enumerate() {
            assert_eq!(matrix.get(k, k), expected);
        }
    }

    #[test]
    fn test_matrix_power() {
        let base = diag_matrix(&[2.0, 3.0], None).unwrap();

        // Squaring a diagonal matrix squares each diagonal entry.
        let squared = matrix_power(&base, 2).unwrap();
        assert_eq!(squared.get(0, 0), 4.0);
        assert_eq!(squared.get(1, 1), 9.0);

        // Power zero yields the identity.
        let identity = matrix_power(&base, 0).unwrap();
        assert_eq!(identity.get(0, 0), 1.0);
        assert_eq!(identity.get(1, 1), 1.0);
    }
}