1use crate::error::{SparseError, SparseResult};
4use crate::linalg::interface::LinearOperator;
5use scirs2_core::numeric::{Float, NumAssign, SparseElement};
6use std::fmt::Debug;
7use std::iter::Sum;
8
/// Options controlling the LGMRES iterative solver.
pub struct LGMRESOptions<F> {
    /// Maximum number of outer (restart) iterations.
    pub max_iter: usize,
    /// Relative tolerance: converge when `||r|| / ||b|| < rtol`.
    pub rtol: F,
    /// Absolute tolerance: converge when `||r|| < atol`.
    pub atol: F,
    /// Dimension of the Krylov subspace built in each inner GMRES cycle.
    pub inner_m: usize,
    /// Maximum number of basis vectors carried between restarts.
    pub outer_k: usize,
    /// Optional initial guess; a zero vector is used when `None`.
    pub x0: Option<Vec<F>>,
    /// Optional preconditioner, applied to the residual `b - A*x`
    /// (left preconditioning, as used by the solver below).
    pub preconditioner: Option<Box<dyn LinearOperator<F>>>,
}
26
27impl<F: Float> Default for LGMRESOptions<F> {
28 fn default() -> Self {
29 Self {
30 max_iter: 500,
31 rtol: F::from(1e-8).unwrap(),
32 atol: F::from(1e-8).unwrap(),
33 inner_m: 30,
34 outer_k: 3,
35 x0: None,
36 preconditioner: None,
37 }
38 }
39}
40
/// Outcome of an LGMRES solve.
#[derive(Debug, Clone)]
pub struct LGMRESResult<F> {
    /// Computed solution vector.
    pub x: Vec<F>,
    /// Number of outer iterations performed.
    pub iterations: usize,
    /// Final residual norm (of the preconditioned residual when a
    /// preconditioner was supplied).
    pub residual_norm: F,
    /// `true` when the tolerances were met within `max_iter` iterations.
    pub converged: bool,
}
53
54#[allow(dead_code)]
59pub fn lgmres<F>(
60 a: &dyn LinearOperator<F>,
61 b: &[F],
62 options: LGMRESOptions<F>,
63) -> SparseResult<LGMRESResult<F>>
64where
65 F: Float + SparseElement + NumAssign + Sum + Debug + 'static,
66{
67 let (m, n) = a.shape();
68 if m != n {
69 return Err(SparseError::DimensionMismatch {
70 expected: m,
71 found: n,
72 });
73 }
74
75 if b.len() != n {
76 return Err(SparseError::DimensionMismatch {
77 expected: n,
78 found: b.len(),
79 });
80 }
81
82 let mut x = options
84 .x0
85 .clone()
86 .unwrap_or_else(|| vec![F::sparse_zero(); n]);
87 let b_norm = b.iter().map(|&bi| bi * bi).sum::<F>().sqrt();
88
89 if b_norm < options.atol {
90 return Ok(LGMRESResult {
91 x,
92 iterations: 0,
93 residual_norm: F::sparse_zero(),
94 converged: true,
95 });
96 }
97
98 let mut r = if let Some(ref m) = options.preconditioner {
100 let ax = a.matvec(&x)?;
101 let residual: Vec<F> = b
102 .iter()
103 .zip(ax.iter())
104 .map(|(&bi, &axi)| bi - axi)
105 .collect();
106 m.matvec(&residual)?
107 } else {
108 let ax = a.matvec(&x)?;
109 b.iter()
110 .zip(ax.iter())
111 .map(|(&bi, &axi)| bi - axi)
112 .collect()
113 };
114
115 let mut r_norm = r.iter().map(|&ri| ri * ri).sum::<F>().sqrt();
116
117 if r_norm < options.atol || r_norm / b_norm < options.rtol {
118 return Ok(LGMRESResult {
119 x,
120 iterations: 0,
121 residual_norm: r_norm,
122 converged: true,
123 });
124 }
125
126 let mut augmented_vectors: Vec<Vec<F>> = Vec::new();
128 let mut total_iter = 0;
129
130 for _ in 0..options.max_iter {
131 let (y, new_r_norm, v_list) = inner_gmres(
133 a,
134 &r,
135 options.inner_m,
136 &augmented_vectors,
137 options.preconditioner.as_deref(),
138 )?;
139
140 for (xi, &yi) in x.iter_mut().zip(y.iter()) {
142 *xi += yi;
143 }
144
145 total_iter += 1;
146
147 if let Some(ref m) = options.preconditioner {
149 let ax = a.matvec(&x)?;
150 let residual: Vec<F> = b
151 .iter()
152 .zip(ax.iter())
153 .map(|(&bi, &axi)| bi - axi)
154 .collect();
155 r = m.matvec(&residual)?;
156 } else {
157 let ax = a.matvec(&x)?;
158 r = b
159 .iter()
160 .zip(ax.iter())
161 .map(|(&bi, &axi)| bi - axi)
162 .collect();
163 }
164
165 r_norm = r.iter().map(|&ri| ri * ri).sum::<F>().sqrt();
166
167 if r_norm < options.atol || r_norm / b_norm < options.rtol {
169 return Ok(LGMRESResult {
170 x,
171 iterations: total_iter,
172 residual_norm: r_norm,
173 converged: true,
174 });
175 }
176
177 for v in v_list {
179 augmented_vectors.push(v);
180 }
181 if augmented_vectors.len() > options.outer_k {
182 augmented_vectors.drain(0..augmented_vectors.len() - options.outer_k);
183 }
184 }
185
186 Ok(LGMRESResult {
187 x,
188 iterations: total_iter,
189 residual_norm: r_norm,
190 converged: false,
191 })
192}
193
/// One restart cycle of GMRES: builds an Arnoldi basis of dimension at most
/// `m` for the Krylov space of the (optionally left-preconditioned) operator,
/// solves the small least-squares problem via Givens rotations, and returns
/// `(correction, residual_norm_estimate, basis_vectors)`.
///
/// NOTE(review): the `augmented_vectors` parameter is accepted but never
/// read, so the LGMRES augmentation is not actually applied here — this
/// behaves as plain restarted GMRES. TODO: augment the subspace with these
/// vectors (cf. Baker/Jessup/Manteuffel LGMRES).
#[allow(dead_code)]
fn inner_gmres<F>(
    a: &dyn LinearOperator<F>,
    r0: &[F],
    m: usize,
    augmented_vectors: &[Vec<F>],
    preconditioner: Option<&dyn LinearOperator<F>>,
) -> SparseResult<(Vec<F>, F, Vec<Vec<F>>)>
where
    F: Float + SparseElement + NumAssign + Sum + Debug + 'static,
{
    let n = r0.len();

    // Arnoldi basis vectors v[0..=m]; v[j] is the j-th orthonormal direction.
    let mut v = vec![vec![F::sparse_zero(); n]; m + 1];
    let r0_norm = r0.iter().map(|&ri| ri * ri).sum::<F>().sqrt();

    // Zero starting residual: nothing to do.
    if r0_norm < F::epsilon() {
        return Ok((vec![F::sparse_zero(); n], F::sparse_zero(), vec![]));
    }

    v[0] = r0.iter().map(|&ri| ri / r0_norm).collect();

    // Upper-Hessenberg matrix h, Givens sine/cosine pairs (s, c), and the
    // progressively rotated right-hand side `beta` of the least-squares
    // problem (beta[0] starts at ||r0||).
    let mut h = vec![vec![F::sparse_zero(); m]; m + 1];
    let mut s = vec![F::sparse_zero(); m + 1];
    let mut c = vec![F::sparse_zero(); m + 1];
    let mut beta = vec![F::sparse_zero(); m + 2];
    beta[0] = r0_norm;

    // `k` tracks how many columns of the Hessenberg system are usable.
    let mut k = 0; for j in 0..m {
        // w = M^-1 * A * v_j: one application of the preconditioned operator.
        let w = if let Some(prec) = preconditioner {
            let av = a.matvec(&v[j])?;
            prec.matvec(&av)?
        } else {
            a.matvec(&v[j])?
        };

        // Gram-Schmidt: orthogonalize w against v[0..=j].
        // (Projection coefficients are taken against the original w, i.e.
        // classical rather than modified Gram-Schmidt.)
        let mut w_orth = w.clone();
        for i in 0..=j {
            let h_ij = w
                .iter()
                .zip(v[i].iter())
                .map(|(&wi, &vi)| wi * vi)
                .sum::<F>();
            h[i][j] = h_ij;
            for (idx, w_elem) in w_orth.iter_mut().enumerate().take(n) {
                *w_elem -= h_ij * v[i][idx];
            }
        }

        let h_jp1_j = w_orth.iter().map(|&wi| wi * wi).sum::<F>().sqrt();

        if h_jp1_j > F::epsilon() {
            h[j + 1][j] = h_jp1_j;
            v[j + 1] = w_orth.iter().map(|&wi| wi / h_jp1_j).collect();
        } else {
            // "Happy breakdown": the Krylov space is exhausted.
            // NOTE(review): this breaks before the Givens rotations below are
            // applied to column j, so that column enters the triangular solve
            // un-rotated — verify this path against a reference GMRES.
            k = j + 1;
            break;
        }

        // Apply the previously computed rotations to the new column j.
        for i in 0..j {
            let temp = c[i] * h[i][j] + s[i] * h[i + 1][j];
            h[i + 1][j] = -s[i] * h[i][j] + c[i] * h[i + 1][j];
            h[i][j] = temp;
        }

        // Form a new rotation that annihilates the subdiagonal h[j+1][j].
        let h_jj = h[j][j];
        let h_jp1_j = h[j + 1][j];
        let rho = (h_jj * h_jj + h_jp1_j * h_jp1_j).sqrt();

        if rho > F::epsilon() {
            c[j] = h_jj / rho;
            s[j] = h_jp1_j / rho;

            h[j][j] = c[j] * h_jj + s[j] * h_jp1_j;
            h[j + 1][j] = F::sparse_zero();

            // Update the rotated RHS; |beta[j+1]| estimates the current
            // residual norm, so a tiny value means early convergence.
            beta[j + 1] = -s[j] * beta[j];
            beta[j] = c[j] * beta[j];

            k = j + 1; if beta[j + 1].abs() < F::from(1e-10).unwrap() {
                break;
            }
        } else {
            // Column is numerically zero after rotation; drop it.
            k = j;
            break;
        }
    }

    // No usable columns: return a zero correction with the initial residual.
    if k == 0 {
        return Ok((vec![F::sparse_zero(); n], r0_norm, vec![]));
    }

    // Back-substitution on the k-by-k upper-triangular system H y = beta.
    let mut y = vec![F::sparse_zero(); k];
    for i in (0..k).rev() {
        y[i] = beta[i];
        for j in (i + 1)..k {
            y[i] = y[i] - h[i][j] * y[j];
        }
        if h[i][i].abs() > F::epsilon() {
            y[i] /= h[i][i];
        } else {
            // Guard against a (near-)singular diagonal entry.
            y[i] = F::sparse_zero();
        }
    }

    // Map the small solution back: correction x = V * y.
    let mut x = vec![F::sparse_zero(); n];
    for i in 0..k {
        for (j, x_val) in x.iter_mut().enumerate().take(n) {
            *x_val += y[i] * v[i][j];
        }
    }

    // Hand back interior basis vectors (v[1..k]) for the caller to carry
    // across restarts.
    let v_list: Vec<Vec<F>> = if k > 1 {
        v.into_iter().skip(1).take(k - 1).collect()
    } else {
        vec![]
    };

    Ok((x, beta[k].abs(), v_list))
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;
    use crate::linalg::interface::{AsLinearOperator, IdentityOperator};

    /// Solving `I x = b` must return `b` itself.
    #[test]
    fn test_lgmres_identity() {
        let op = IdentityOperator::new(3);
        let rhs = vec![1.0, 2.0, 3.0];

        let result = lgmres(&op, &rhs, LGMRESOptions::default()).unwrap();

        assert!(result.converged);
        for (computed, expected) in result.x.iter().zip(rhs.iter()) {
            assert!((computed - expected).abs() < 1e-10);
        }
    }

    /// A small symmetric positive-definite tridiagonal system converges to
    /// within the default tolerance.
    #[test]
    fn test_lgmres_spd_matrix() {
        let values = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let row_ptr = vec![0, 2, 5, 7];
        let col_idx = vec![0, 1, 0, 1, 2, 1, 2];
        let matrix = CsrMatrix::from_raw_csr(values, row_ptr, col_idx, (3, 3)).unwrap();
        let op = matrix.as_linear_operator();

        let rhs = vec![1.0, 2.0, 3.0];
        let result = lgmres(op.as_ref(), &rhs, LGMRESOptions::default()).unwrap();

        assert!(result.converged);
        assert!(result.residual_norm < 1e-8);
    }
}
372}