scirs2_sparse/linalg/
spai.rs

1//! Sparse Approximate Inverse (SPAI) preconditioner
2
3use crate::csr::CsrMatrix;
4use crate::error::{SparseError, SparseResult};
5use crate::linalg::interface::LinearOperator;
6use scirs2_core::numeric::{Float, NumAssign, SparseElement};
7use std::fmt::Debug;
8use std::iter::Sum;
9
10/// Sparse Approximate Inverse (SPAI) preconditioner
11///
12/// This preconditioner computes a sparse approximate inverse M ≈ A^(-1)
13/// using a minimization approach: minimize ||I - AM||_F subject to
14/// sparsity constraints on M.
15pub struct SpaiPreconditioner<F> {
16    /// Sparse approximate inverse of the original matrix
17    approx_inverse: CsrMatrix<F>,
18}
19
/// Options for the SPAI preconditioner.
///
/// Only `max_nnz_per_col` is currently read by `SpaiPreconditioner::new`.
/// `ls_tolerance` and `max_ls_iters` are accepted so the options type stays
/// stable, but they are not consulted by the present construction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpaiOptions {
    /// Maximum number of nonzeros per column of `M`.
    ///
    /// The diagonal entry is always part of a column's pattern, so `0`
    /// behaves like `1`. The current banded pattern uses at most two
    /// neighbours on each side of the diagonal, so values above `5` have no
    /// further effect.
    pub max_nnz_per_col: usize,
    /// Tolerance for the least-squares solver (not currently read by `SpaiPreconditioner::new`).
    pub ls_tolerance: f64,
    /// Maximum iterations for the least-squares solver (not currently read by `SpaiPreconditioner::new`).
    pub max_ls_iters: usize,
}
29
30impl Default for SpaiOptions {
31    fn default() -> Self {
32        Self {
33            max_nnz_per_col: 10,
34            ls_tolerance: 1e-10,
35            max_ls_iters: 100,
36        }
37    }
38}
39
40impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> SpaiPreconditioner<F> {
41    /// Create a new SPAI preconditioner from a sparse matrix
42    pub fn new(matrix: &CsrMatrix<F>, options: SpaiOptions) -> SparseResult<Self> {
43        let n = matrix.rows();
44        if n != matrix.cols() {
45            return Err(SparseError::DimensionMismatch {
46                expected: n,
47                found: matrix.cols(),
48            });
49        }
50
51        // For now, we'll implement a simplified version of SPAI
52        // that uses a static sparsity pattern (diagonal + few off-diagonals)
53
54        // Initialize M as identity _matrix in dense format
55        let mut m_dense = vec![vec![F::sparse_zero(); n]; n];
56        for (i, row) in m_dense.iter_mut().enumerate().take(n) {
57            row[i] = F::sparse_one();
58        }
59
60        // For each column of M, solve a least squares problem
61        for j in 0..n {
62            // Define sparsity pattern for column j
63            // For simplicity, we'll use a pattern that includes the diagonal
64            // and a few neighboring elements
65            let mut pattern = vec![j];
66
67            // Add neighbors within distance 2
68            let start = j.saturating_sub(2);
69            let end = (j + 3).min(n);
70
71            for k in start..end {
72                if k != j && pattern.len() < options.max_nnz_per_col {
73                    pattern.push(k);
74                }
75            }
76
77            // Extract the relevant submatrix A_k from A
78            let k = pattern.len();
79            let mut a_k = vec![vec![F::sparse_zero(); k]; n];
80
81            for (col_idx, &col) in pattern.iter().enumerate() {
82                for (row, a_k_row) in a_k.iter_mut().enumerate().take(n) {
83                    let val = matrix.get(row, col);
84                    a_k_row[col_idx] = val;
85                }
86            }
87
88            // Set up right-hand side (j-th unit vector)
89            let mut e_j = vec![F::sparse_zero(); n];
90            e_j[j] = F::sparse_one();
91
92            // Solve least squares problem: minimize ||A_k * m_k - e_j||
93            // For now, use a simple normal equations approach
94            // A_k^T * A_k * m_k = A_k^T * e_j
95
96            // Compute A_k^T * A_k
97            let mut ata = vec![vec![F::sparse_zero(); k]; k];
98            for i in 0..k {
99                for j_inner in 0..k {
100                    let mut sum = F::sparse_zero();
101                    for a_k_row in a_k.iter().take(n) {
102                        sum += a_k_row[i] * a_k_row[j_inner];
103                    }
104                    ata[i][j_inner] = sum;
105                }
106            }
107
108            // Compute A_k^T * e_j
109            let mut atb = vec![F::sparse_zero(); k];
110            atb[..k].copy_from_slice(&a_k[j][..k]);
111
112            // Solve the system (simplified - in practice, use proper solver)
113            let m_k = solve_dense_system(&ata, &atb)?;
114
115            // Update M with the computed values
116            for (idx, &row) in pattern.iter().enumerate() {
117                m_dense[row][j] = m_k[idx];
118            }
119        }
120
121        // Convert dense M to sparse format manually
122        let n = m_dense.len();
123        let mut data = Vec::new();
124        let mut indices = Vec::new();
125        let mut indptr = vec![0];
126
127        for row in m_dense.iter().take(n) {
128            for (j, &val) in row.iter().enumerate().take(n) {
129                if val.abs() > F::epsilon() {
130                    data.push(val);
131                    indices.push(j);
132                }
133            }
134            indptr.push(data.len());
135        }
136
137        let approx_inverse = CsrMatrix::from_raw_csr(data, indptr, indices, (n, n))?;
138
139        Ok(Self { approx_inverse })
140    }
141}
142
143impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> LinearOperator<F>
144    for SpaiPreconditioner<F>
145{
146    fn shape(&self) -> (usize, usize) {
147        self.approx_inverse.shape()
148    }
149
150    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
151        if x.len() != self.approx_inverse.cols() {
152            return Err(SparseError::DimensionMismatch {
153                expected: self.approx_inverse.cols(),
154                found: x.len(),
155            });
156        }
157
158        let mut result = vec![F::sparse_zero(); self.approx_inverse.rows()];
159
160        for (row_idx, result_val) in result.iter_mut().enumerate() {
161            for j in self.approx_inverse.indptr[row_idx]..self.approx_inverse.indptr[row_idx + 1] {
162                let col_idx = self.approx_inverse.indices[j];
163                *result_val += self.approx_inverse.data[j] * x[col_idx];
164            }
165        }
166
167        Ok(result)
168    }
169
170    fn has_adjoint(&self) -> bool {
171        true
172    }
173
174    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
175        // For transpose multiplication, we implement A^T * x
176        if x.len() != self.approx_inverse.rows() {
177            return Err(SparseError::DimensionMismatch {
178                expected: self.approx_inverse.rows(),
179                found: x.len(),
180            });
181        }
182
183        let mut result = vec![F::sparse_zero(); self.approx_inverse.cols()];
184
185        for (row_idx, &x_val) in x.iter().enumerate() {
186            for j in self.approx_inverse.indptr[row_idx]..self.approx_inverse.indptr[row_idx + 1] {
187                let col_idx = self.approx_inverse.indices[j];
188                result[col_idx] += self.approx_inverse.data[j] * x_val;
189            }
190        }
191
192        Ok(result)
193    }
194}
195
196/// Solve a dense linear system using Gaussian elimination with partial pivoting
197#[allow(dead_code)]
198fn solve_dense_system<F: Float + NumAssign + SparseElement>(
199    a: &[Vec<F>],
200    b: &[F],
201) -> SparseResult<Vec<F>> {
202    let n = a.len();
203    if n == 0 || n != a[0].len() || n != b.len() {
204        return Err(SparseError::DimensionMismatch {
205            expected: n,
206            found: b.len(),
207        });
208    }
209
210    // Create augmented matrix
211    let mut aug = vec![vec![F::sparse_zero(); n + 1]; n];
212    for i in 0..n {
213        for j in 0..n {
214            aug[i][j] = a[i][j];
215        }
216        aug[i][n] = b[i];
217    }
218
219    // Gaussian elimination with partial pivoting
220    for k in 0..n {
221        // Find pivot
222        let mut max_row = k;
223        let mut max_val = aug[k][k].abs();
224        for (i, aug_row) in aug.iter().enumerate().skip(k + 1).take(n - k - 1) {
225            let val_abs = aug_row[k].abs();
226            if val_abs > max_val {
227                max_val = val_abs;
228                max_row = i;
229            }
230        }
231
232        // Check for singularity
233        if max_val < F::from(1e-14).unwrap() {
234            return Err(SparseError::ValueError(
235                "Matrix is singular or nearly singular".to_string(),
236            ));
237        }
238
239        // Swap rows
240        if max_row != k {
241            aug.swap(k, max_row);
242        }
243
244        // Eliminate below
245        for i in (k + 1)..n {
246            let factor = aug[i][k] / aug[k][k];
247            for j in k..=n {
248                aug[i][j] = aug[i][j] - factor * aug[k][j];
249            }
250        }
251    }
252
253    // Back substitution
254    let mut x = vec![F::sparse_zero(); n];
255    for i in (0..n).rev() {
256        x[i] = aug[i][n];
257        for j in (i + 1)..n {
258            x[i] = x[i] - aug[i][j] * x[j];
259        }
260        x[i] /= aug[i][i];
261    }
262
263    Ok(x)
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;

    /// The tridiagonal matrix
    /// [ 4 -1  0]
    /// [-1  4 -1]
    /// [ 0 -1  4]
    fn tridiagonal_3x3() -> CsrMatrix<f64> {
        let data = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let indptr = vec![0, 2, 5, 7];
        let indices = vec![0, 1, 0, 1, 2, 1, 2];
        CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3)).unwrap()
    }

    /// Element-wise comparison with an absolute tolerance of 1e-10.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-10, "entry {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_spai_simple() {
        // For n = 3 the +/-2 band covers every column, so M is the exact
        // inverse and M * b solves A x = b: x = [13/28, 6/7, 27/28].
        let matrix = tridiagonal_3x3();
        let preconditioner = SpaiPreconditioner::new(&matrix, SpaiOptions::default()).unwrap();

        let x = preconditioner.matvec(&[1.0, 2.0, 3.0]).unwrap();
        assert_close(&x, &[13.0 / 28.0, 6.0 / 7.0, 27.0 / 28.0]);

        // A^(-1) of a symmetric matrix is symmetric, so M^T b matches M b.
        let xt = preconditioner.rmatvec(&[1.0, 2.0, 3.0]).unwrap();
        assert_close(&xt, &x);
    }

    #[test]
    fn test_spai_diagonal() {
        // Test with a diagonal matrix (should get exact inverse)
        // A = [2   0   0]
        //     [0   3   0]
        //     [0   0   4]
        let data = vec![2.0, 3.0, 4.0];
        let indptr = vec![0, 1, 2, 3];
        let indices = vec![0, 1, 2];
        let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3)).unwrap();

        let preconditioner = SpaiPreconditioner::new(&matrix, SpaiOptions::default()).unwrap();

        // Each unit vector maps to the matching column of diag(1/2, 1/3, 1/4),
        // including exact zeros off the diagonal.
        assert_close(&preconditioner.matvec(&[1.0, 0.0, 0.0]).unwrap(), &[0.5, 0.0, 0.0]);
        assert_close(
            &preconditioner.matvec(&[0.0, 1.0, 0.0]).unwrap(),
            &[0.0, 1.0 / 3.0, 0.0],
        );
        assert_close(&preconditioner.matvec(&[0.0, 0.0, 1.0]).unwrap(), &[0.0, 0.0, 0.25]);

        assert!(preconditioner.has_adjoint());
        assert_close(&preconditioner.rmatvec(&[0.0, 3.0, 0.0]).unwrap(), &[0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_spai_single_entry_pattern() {
        // With one nonzero per column, m_jj = a_jj / ||A(:, j)||^2:
        // columns have squared norms 17, 18, 17, giving diag(4/17, 2/9, 4/17).
        let matrix = tridiagonal_3x3();
        let options = SpaiOptions {
            max_nnz_per_col: 1,
            ..SpaiOptions::default()
        };
        let preconditioner = SpaiPreconditioner::new(&matrix, options).unwrap();

        let x = preconditioner.matvec(&[1.0, 1.0, 1.0]).unwrap();
        assert_close(&x, &[4.0 / 17.0, 2.0 / 9.0, 4.0 / 17.0]);
    }

    #[test]
    fn test_spai_rejects_non_square() {
        let matrix =
            CsrMatrix::from_raw_csr(vec![1.0, 1.0], vec![0, 1, 2], vec![0, 1], (2, 3)).unwrap();
        assert!(SpaiPreconditioner::new(&matrix, SpaiOptions::default()).is_err());
    }
}