1use crate::error::{SparseError, SparseResult};
2use crate::linalg::interface::LinearOperator;
3use scirs2_core::numeric::{Float, NumAssign, SparseElement};
4use std::fmt::Display;
5use std::iter::Sum;
6
/// Outcome of a call to [`qmr`].
///
/// `Ok(QMRResult)` is returned both on success and on a non-fatal stop
/// (iteration cap or numerical breakdown); inspect `converged` and `message`
/// to tell them apart.
#[derive(Debug, Clone)]
pub struct QMRResult<F> {
    /// Last iterate produced by the solver.
    pub x: Vec<F>,
    /// Number of iterations completed before the solver returned.
    pub iterations: usize,
    /// Euclidean norm of the residual vector the solver maintained when it stopped.
    pub residual_norm: F,
    /// `true` only when `residual_norm` dropped below the requested tolerance.
    pub converged: bool,
    /// Human-readable explanation of why the solver stopped.
    pub message: String,
}
16
17pub struct QMROptions<F> {
19 pub max_iter: usize,
20 pub rtol: F,
21 pub atol: F,
22 pub x0: Option<Vec<F>>,
23 pub left_preconditioner: Option<Box<dyn LinearOperator<F>>>,
24 pub right_preconditioner: Option<Box<dyn LinearOperator<F>>>,
25}
26
27impl<F: Float> Default for QMROptions<F> {
28 fn default() -> Self {
29 Self {
30 max_iter: 1000,
31 rtol: F::from(1e-8).unwrap(),
32 atol: F::from(1e-12).unwrap(),
33 x0: None,
34 left_preconditioner: None,
35 right_preconditioner: None,
36 }
37 }
38}
39
40#[allow(dead_code)]
45pub fn qmr<F>(
46 a: &dyn LinearOperator<F>,
47 b: &[F],
48 options: QMROptions<F>,
49) -> SparseResult<QMRResult<F>>
50where
51 F: Float + SparseElement + NumAssign + Sum + Display + 'static,
52{
53 let n = b.len();
54
55 if a.shape().0 != n || a.shape().1 != n {
57 return Err(SparseError::DimensionMismatch {
58 expected: n,
59 found: a.shape().0,
60 });
61 }
62
63 let mut x = options.x0.unwrap_or_else(|| vec![F::sparse_zero(); n]);
65
66 let mut r = if !x.iter().all(|&xi| xi == F::sparse_zero()) {
68 let ax = a.matvec(&x)?;
69 vec_sub(b, &ax)
70 } else {
71 b.to_vec()
72 };
73
74 if let Some(ref ml) = options.left_preconditioner {
76 r = ml.matvec(&r)?;
77 }
78
79 let r_tilde = r.clone();
81
82 let mut p = vec![F::sparse_zero(); n];
84 let mut p_tilde = vec![F::sparse_zero(); n];
85 let mut q = vec![F::sparse_zero(); n];
86 let mut q_tilde = vec![F::sparse_zero(); n];
87
88 let mut rho = F::sparse_one();
90 let mut rho_old;
91 let mut alpha = F::sparse_zero();
92 let mut omega = F::sparse_one();
93
94 let bnorm = norm2(b);
96 let mut rnorm = norm2(&r);
97 let tol = options.atol + options.rtol * bnorm;
98
99 if rnorm < tol {
101 return Ok(QMRResult {
102 x,
103 iterations: 0,
104 residual_norm: rnorm,
105 converged: true,
106 message: "Converged at initial guess".to_string(),
107 });
108 }
109
110 for iter in 0..options.max_iter {
112 rho_old = rho;
114
115 rho = dot(&r_tilde, &r);
117
118 if rho.abs() < F::epsilon() * F::from(10).unwrap() {
120 return Ok(QMRResult {
121 x,
122 iterations: iter,
123 residual_norm: rnorm,
124 converged: false,
125 message: "Breakdown: rho = 0".to_string(),
126 });
127 }
128
129 let beta = if iter == 0 {
131 F::sparse_zero()
132 } else {
133 (rho / rho_old) * (alpha / omega)
134 };
135
136 p = if iter == 0 {
138 r.clone()
139 } else {
140 vec_add(&r, &vec_scaled(&vec_sub(&p, &vec_scaled(&q, omega)), beta))
141 };
142
143 let p_prec = if let Some(ref mr) = options.right_preconditioner {
145 mr.matvec(&p)?
146 } else {
147 p.clone()
148 };
149
150 q = a.matvec(&p_prec)?;
151
152 if let Some(ref ml) = options.left_preconditioner {
153 q = ml.matvec(&q)?;
154 }
155
156 p_tilde = if iter == 0 {
158 r_tilde.clone()
159 } else {
160 let diff = vec_sub(&p_tilde, &vec_scaled(&q_tilde, omega));
161 vec_add(&r_tilde, &vec_scaled(&diff, beta))
162 };
163
164 let p_tilde_prec = if let Some(ref ml) = options.left_preconditioner {
166 ml.rmatvec(&p_tilde)?
167 } else {
168 p_tilde.clone()
169 };
170
171 q_tilde = a.rmatvec(&p_tilde_prec)?;
172
173 if let Some(ref mr) = options.right_preconditioner {
174 q_tilde = mr.rmatvec(&q_tilde)?;
175 }
176
177 let dot_pq = dot(&p_tilde, &q);
179 if dot_pq.abs() < F::epsilon() * F::from(10).unwrap() {
180 return Ok(QMRResult {
181 x,
182 iterations: iter,
183 residual_norm: rnorm,
184 converged: false,
185 message: "Breakdown: <p_tilde, q> = 0".to_string(),
186 });
187 }
188
189 alpha = rho / dot_pq;
190
191 let s = vec_sub(&r, &vec_scaled(&q, alpha));
193 let _s_tilde = vec_sub(&r_tilde, &vec_scaled(&q_tilde, alpha));
194
195 let s_prec = if let Some(ref mr) = options.right_preconditioner {
197 mr.matvec(&s)?
198 } else {
199 s.clone()
200 };
201
202 let t = a.matvec(&s_prec)?;
203 let t = if let Some(ref ml) = options.left_preconditioner {
204 ml.matvec(&t)?
205 } else {
206 t
207 };
208
209 let dot_tt = dot(&t, &t);
211 if dot_tt == F::sparse_zero() {
212 omega = F::sparse_zero();
213 } else {
214 omega = dot(&t, &s) / dot_tt;
215 }
216
217 x = vec_add(&x, &vec_scaled(&p_prec, alpha));
219 x = vec_add(&x, &vec_scaled(&s_prec, omega));
220
221 r = vec_sub(&s, &vec_scaled(&t, omega));
223
224 rnorm = norm2(&r);
226
227 if rnorm < tol {
229 return Ok(QMRResult {
230 x,
231 iterations: iter + 1,
232 residual_norm: rnorm,
233 converged: true,
234 message: format!("Converged in {} iterations", iter + 1),
235 });
236 }
237
238 if omega.abs() < F::epsilon() {
240 return Ok(QMRResult {
241 x,
242 iterations: iter + 1,
243 residual_norm: rnorm,
244 converged: false,
245 message: "Breakdown: omega = 0".to_string(),
246 });
247 }
248 }
249
250 Ok(QMRResult {
251 x,
252 iterations: options.max_iter,
253 residual_norm: rnorm,
254 converged: false,
255 message: format!(
256 "Did not converge in {} iterations. Residual: {}",
257 options.max_iter, rnorm
258 ),
259 })
260}
261
262#[allow(dead_code)]
264fn dot<F: Float + Sum>(a: &[F], b: &[F]) -> F {
265 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
266}
267
268#[allow(dead_code)]
269fn norm2<F: Float + Sum>(v: &[F]) -> F {
270 v.iter().map(|&vi| vi * vi).sum::<F>().sqrt()
271}
272
273#[allow(dead_code)]
274fn vec_add<F: Float>(a: &[F], b: &[F]) -> Vec<F> {
275 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai + bi).collect()
276}
277
278#[allow(dead_code)]
279fn vec_sub<F: Float>(a: &[F], b: &[F]) -> Vec<F> {
280 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai - bi).collect()
281}
282
283#[allow(dead_code)]
284fn vec_scaled<F: Float>(v: &[F], s: F) -> Vec<F> {
285 v.iter().map(|&vi| vi * s).collect()
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::interface::{DiagonalOperator, IdentityOperator};

    /// Asserts `actual[i]` is within `tol` of `expected[i]` for every index of
    /// `expected` (panics on out-of-range indexing, like direct indexing would).
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual[i];
            assert!((got - want).abs() < tol, "x[{}] = {} != {}", i, got, want);
        }
    }

    /// On the identity operator the first iteration must already return `b`.
    #[test]
    fn test_qmr_identity() {
        let op = IdentityOperator::<f64>::new(3);
        let rhs = vec![1.0, 2.0, 3.0];

        let result = qmr(&op, &rhs, QMROptions::default()).unwrap();

        assert!(result.converged);
        assert_eq!(result.iterations, 1);
        assert_all_close(&result.x, &rhs, 1e-10);
    }

    /// A small, well-conditioned diagonal system converges to its exact solution.
    #[test]
    fn test_qmr_diagonal() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let rhs = vec![2.0, 6.0, 8.0];
        let expected = [1.0, 2.0, 2.0];
        let options = QMROptions {
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };

        let result = qmr(&op, &rhs, options).unwrap();

        assert!(result.converged);
        assert!(result.iterations <= 10);
        assert_all_close(&result.x, &expected, 1e-9);
    }

    /// A nearby starting guess is corrected within a single iteration.
    #[test]
    fn test_qmr_with_initial_guess() {
        let op = IdentityOperator::<f64>::new(3);
        let rhs = vec![1.0, 2.0, 3.0];
        let options = QMROptions {
            x0: Some(vec![0.9, 1.9, 2.9]),
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };

        let result = qmr(&op, &rhs, options).unwrap();

        assert!(result.converged);
        assert!(result.iterations <= 1);
        assert_all_close(&result.x, &rhs, 1e-10);
    }

    /// With a tight iteration cap, a non-converged run must report the cap.
    #[test]
    fn test_qmr_max_iterations() {
        let op = DiagonalOperator::new(vec![1e-8, 1.0, 1.0]);
        let rhs = vec![1.0, 1.0, 1.0];
        let options = QMROptions {
            max_iter: 5,
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };

        let result = qmr(&op, &rhs, options).unwrap();

        // Converging within the cap is acceptable; otherwise the cap is reported.
        if !result.converged {
            assert_eq!(result.iterations, 5);
            assert!(result.message.contains("Did not converge"));
        }
    }
}
390}