//! Conjugate Gradient Squared (CGS) solver — `scirs2_sparse/linalg/cgs.rs`.

1use crate::error::{SparseError, SparseResult};
2use crate::linalg::interface::LinearOperator;
3use crate::linalg::iterative::{dot, norm2, BiCGOptions, IterationResult};
4use scirs2_core::numeric::{Float, NumAssign, SparseElement};
5use std::iter::Sum;
6
/// Options for the CGS solver (shared with the BiCG solver).
pub type CGSOptions<F> = BiCGOptions<F>;
/// Result of a CGS solve: solution vector `x`, iteration count,
/// final residual norm, convergence flag, and a status message.
pub type CGSResult<F> = IterationResult<F>;
10
11/// Conjugate Gradient Squared solver (CGS)
12///
13/// Implementation following the algorithm from "Templates for the Solution of Linear Systems"
14/// by Barrett et al. This is for non-symmetric linear systems.
15#[allow(dead_code)]
16pub fn cgs<F>(
17    a: &dyn LinearOperator<F>,
18    b: &[F],
19    options: CGSOptions<F>,
20) -> SparseResult<CGSResult<F>>
21where
22    F: Float + NumAssign + Sum + SparseElement + 'static,
23{
24    let (rows, cols) = a.shape();
25    if rows != cols {
26        return Err(SparseError::ValueError(
27            "Matrix must be square for CGS solver".to_string(),
28        ));
29    }
30    if b.len() != rows {
31        return Err(SparseError::DimensionMismatch {
32            expected: rows,
33            found: b.len(),
34        });
35    }
36
37    let n = rows;
38
39    // Initialize solution
40    let mut x: Vec<F> = match &options.x0 {
41        Some(x0) => {
42            if x0.len() != n {
43                return Err(SparseError::DimensionMismatch {
44                    expected: n,
45                    found: x0.len(),
46                });
47            }
48            x0.clone()
49        }
50        None => vec![F::sparse_zero(); n],
51    };
52
53    // Compute initial residual: r = b - A*x
54    let ax = a.matvec(&x)?;
55    let mut r: Vec<F> = b.iter().zip(&ax).map(|(&bi, &axi)| bi - axi).collect();
56
57    // Check if initial guess is solution
58    let mut rnorm = norm2(&r);
59    let bnorm = norm2(b);
60    let tolerance = F::max(options.atol, options.rtol * bnorm);
61
62    if rnorm <= tolerance {
63        return Ok(CGSResult {
64            x,
65            iterations: 0,
66            residual_norm: rnorm,
67            converged: true,
68            message: "Converged with initial guess".to_string(),
69        });
70    }
71
72    // Choose arbitrary r̃ (usually r̃ = r)
73    let r_tilde = r.clone();
74
75    // Initialize vectors
76    let mut u = vec![F::sparse_zero(); n];
77    let mut p = vec![F::sparse_zero(); n];
78    let mut q = vec![F::sparse_zero(); n];
79
80    let mut rho = F::sparse_one();
81    let mut iterations = 0;
82
83    // Main CGS iteration
84    while iterations < options.max_iter {
85        // Compute ρ = (r̃, r)
86        let rho_new = dot(&r_tilde, &r);
87
88        // Check for breakdown
89        if rho_new.abs() < F::epsilon() * F::from(10).expect("Failed to convert constant to float")
90        {
91            return Ok(CGSResult {
92                x,
93                iterations,
94                residual_norm: rnorm,
95                converged: false,
96                message: "CGS breakdown: rho ≈ 0".to_string(),
97            });
98        }
99
100        // Compute β = ρ_i / ρ_{i-1}
101        let beta = if iterations == 0 {
102            F::sparse_zero()
103        } else {
104            rho_new / rho
105        };
106
107        // Update u and p
108        for i in 0..n {
109            u[i] = r[i] + beta * q[i];
110            p[i] = u[i] + beta * (q[i] + beta * p[i]);
111        }
112
113        // Apply right preconditioner if provided
114        let p_prec = if let Some(m) = &options.right_preconditioner {
115            m.matvec(&p)?
116        } else {
117            p.clone()
118        };
119
120        // v = A * M^{-1} * p
121        let v = a.matvec(&p_prec)?;
122
123        // σ = (r̃, v)
124        let sigma = dot(&r_tilde, &v);
125
126        // Check for breakdown
127        if sigma.abs() < F::epsilon() * F::from(10).expect("Failed to convert constant to float") {
128            return Ok(CGSResult {
129                x,
130                iterations,
131                residual_norm: rnorm,
132                converged: false,
133                message: "CGS breakdown: sigma ≈ 0".to_string(),
134            });
135        }
136
137        // α = ρ / σ
138        let alpha = rho_new / sigma;
139
140        // Update q
141        for i in 0..n {
142            q[i] = u[i] - alpha * v[i];
143        }
144
145        // Compute u + q
146        let u_plus_q: Vec<F> = u.iter().zip(&q).map(|(&ui, &qi)| ui + qi).collect();
147
148        // Apply right preconditioner if provided
149        let u_plus_q_prec = if let Some(m) = &options.right_preconditioner {
150            m.matvec(&u_plus_q)?
151        } else {
152            u_plus_q
153        };
154
155        // Update x
156        for i in 0..n {
157            x[i] += alpha * u_plus_q_prec[i];
158        }
159
160        // Apply right preconditioner to q
161        let q_prec = if let Some(m) = &options.right_preconditioner {
162            m.matvec(&q)?
163        } else {
164            q.clone()
165        };
166
167        // Compute A * M^{-1} * q
168        let aq = a.matvec(&q_prec)?;
169
170        // Update r
171        for i in 0..n {
172            r[i] -= alpha * (v[i] + aq[i]);
173        }
174
175        rho = rho_new;
176        iterations += 1;
177
178        // Check convergence
179        rnorm = norm2(&r);
180        if rnorm <= tolerance {
181            break;
182        }
183    }
184
185    Ok(CGSResult {
186        x,
187        iterations,
188        residual_norm: rnorm,
189        converged: rnorm <= tolerance,
190        message: if rnorm <= tolerance {
191            "Converged".to_string()
192        } else {
193            "Maximum iterations reached".to_string()
194        },
195    })
196}