//! Augmented Krylov subspace methods.
//!
//! These methods augment the standard Krylov space K_m(A, r) with an additional
//! subspace W (e.g., approximate eigenvectors or solution-space information from
//! previous solves) to form an *augmented Krylov space*
//!
//!   K_m^+(A, r, W) = K_m(A, r) + W.
//!
//! This is the theoretical foundation shared by:
//! - GMRES-DR (augment with harmonic Ritz vectors)
//! - GCRO-DR (augment with recycled subspace from previous solves)
//! - Augmented GMRES / LGMRES (augment with error approximations)
//! - Deflated BiCG methods
//!
//! # AugmentedKrylov
//!
//! This module provides a general `AugmentedKrylov` solver that accepts an
//! explicit augmentation subspace and runs GMRES in the augmented space.
//!
//! # References
//!
//! - Saad, Y. (1997). "Analysis of augmented Krylov subspace methods".
//!   SIAM J. Matrix Anal. Appl. 18(2), 435-449.
//! - Baker, A.H., Jessup, E.R., Manteuffel, T. (2005). "A technique for
//!   accelerating the convergence of restarted GMRES". SIAM J. Matrix Anal.
//!   Appl. 26(4), 962-984.
use crate::error::SparseError;
use crate::krylov::gmres_dr::{dot, gram_schmidt_mgs, norm2, solve_least_squares_hessenberg};
/// Configuration for an augmented Krylov solve.
#[derive(Debug, Clone)]
pub struct AugmentedKrylovConfig {
    /// Krylov dimension per restart cycle (not counting augmentation vectors).
    pub krylov_dim: usize,
    /// Convergence tolerance, applied relative to ||b|| (absolute when ||b|| is
    /// numerically zero).
    pub tol: f64,
    /// Maximum number of matrix-vector products across all cycles.
    pub max_iter: usize,
    /// Maximum number of restart cycles.
    pub max_cycles: usize,
}
44impl Default for AugmentedKrylovConfig {
45    fn default() -> Self {
46        Self {
47            krylov_dim: 20,
48            tol: 1e-10,
49            max_iter: 1000,
50            max_cycles: 50,
51        }
52    }
53}
54
/// Result from an augmented Krylov solve.
#[derive(Debug, Clone)]
pub struct AugmentedKrylovResult {
    /// Solution vector x.
    pub x: Vec<f64>,
    /// Final residual norm ||b - Ax||.
    pub residual_norm: f64,
    /// Total matrix-vector products performed (includes augmentation setup).
    pub iterations: usize,
    /// Whether the residual tolerance was reached.
    pub converged: bool,
    /// Residual norm recorded at the start of each restart cycle and at exit.
    pub residual_history: Vec<f64>,
    /// Updated augmentation vectors (orthonormal, extracted from the last
    /// Krylov basis) for reuse in subsequent solves.
    pub new_augmentation: Vec<Vec<f64>>,
}
/// Augmented Krylov subspace solver (augmented GMRES).
///
/// # Overview
///
/// Given an augmentation subspace W (external knowledge vectors -- e.g., from
/// previous solves or approximate eigenvectors), solves A x = b in the space
/// x_0 + span(W) + K_m(A, r_0).
///
/// The solver uses standard GMRES (Arnoldi + least-squares) for the Krylov
/// portion, and incorporates augmentation vectors by projecting the residual
/// onto the range of A*W before each GMRES cycle.
///
/// After convergence, an updated augmentation subspace is extracted from the
/// last Krylov basis for use in subsequent solves.
pub struct AugmentedKrylov {
    /// Solver parameters (Krylov dimension, tolerance, iteration budgets).
    config: AugmentedKrylovConfig,
}
90impl AugmentedKrylov {
91    /// Create an `AugmentedKrylov` solver with the given configuration.
92    pub fn new(config: AugmentedKrylovConfig) -> Self {
93        Self { config }
94    }
95
96    /// Create with default configuration.
97    pub fn with_defaults() -> Self {
98        Self {
99            config: AugmentedKrylovConfig::default(),
100        }
101    }
102
103    /// Solve A x = b with augmented Krylov subspace.
104    ///
105    /// # Arguments
106    ///
107    /// * `matvec` - Closure for the matrix-vector product y = A x.
108    /// * `b` - Right-hand side.
109    /// * `x0` - Optional initial guess.
110    /// * `augmentation` - External vectors to augment the Krylov space.
111    ///   These are incorporated by projecting the residual onto span(A*W) at
112    ///   each restart. Pass an empty slice for standard (non-augmented) GMRES.
113    ///
114    /// # Returns
115    ///
116    /// An `AugmentedKrylovResult` containing the solution and updated augmentation
117    /// vectors.
118    pub fn solve<F>(
119        &self,
120        matvec: F,
121        b: &[f64],
122        x0: Option<&[f64]>,
123        augmentation: &[Vec<f64>],
124    ) -> Result<AugmentedKrylovResult, SparseError>
125    where
126        F: Fn(&[f64]) -> Vec<f64>,
127    {
128        let n = b.len();
129        let mut x = match x0 {
130            Some(v) => v.to_vec(),
131            None => vec![0.0f64; n],
132        };
133
134        let b_norm = norm2(b);
135        let abs_tol = if b_norm > 1e-300 {
136            self.config.tol * b_norm
137        } else {
138            self.config.tol
139        };
140        let mut total_mv = 0usize;
141        let mut residual_history = Vec::new();
142        let mut last_krylov: Vec<Vec<f64>> = Vec::new();
143
144        // Prepare augmentation: orthonormalise the incoming vectors.
145        let mut aug_orth: Vec<Vec<f64>> = augmentation.to_vec();
146        gram_schmidt_mgs(&mut aug_orth);
147        aug_orth.retain(|vi| norm2(vi) > 0.5);
148        let k_aug = aug_orth.len();
149
150        // Pre-compute A*W for augmentation projection (GCRO-style).
151        // We project the residual onto range(A*W) to find the optimal correction
152        // in the augmentation subspace: delta_x = W * (AW)^+ * r.
153        let mut aw: Vec<Vec<f64>> = Vec::with_capacity(k_aug);
154        for j in 0..k_aug {
155            aw.push(matvec(&aug_orth[j]));
156            total_mv += 1;
157        }
158        // Orthonormalise AW columns for stable projection.
159        let mut aw_orth = aw.clone();
160        gram_schmidt_mgs(&mut aw_orth);
161        aw_orth.retain(|vi| norm2(vi) > 0.5);
162
163        for _cycle in 0..self.config.max_cycles {
164            // Compute residual.
165            let ax = matvec(&x);
166            total_mv += 1;
167            let r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, axi)| bi - axi).collect();
168            let r_norm = norm2(&r);
169            residual_history.push(r_norm);
170
171            if r_norm <= abs_tol {
172                let new_aug = extract_augmentation(&last_krylov, k_aug, n);
173                return Ok(AugmentedKrylovResult {
174                    x,
175                    residual_norm: r_norm,
176                    iterations: total_mv,
177                    converged: true,
178                    residual_history,
179                    new_augmentation: new_aug,
180                });
181            }
182
183            if total_mv >= self.config.max_iter {
184                break;
185            }
186
187            // --- Augmentation correction (GCRO-style projection) ---
188            // Compute x += W * (AW)^+ * r = W * ((AW_orth)^T * r) (since AW_orth is ON).
189            // But we need the coefficients w.r.t. the original W, not AW_orth.
190            // Use the relation: project r onto each AW column.
191            if k_aug > 0 {
192                // Solve: min ||r - AW * alpha|| for alpha.
193                // Since we have AW (not orthogonalized), use normal equations:
194                // (AW)^T (AW) alpha = (AW)^T r.
195                let mut ata = vec![vec![0.0f64; k_aug]; k_aug];
196                let mut atr = vec![0.0f64; k_aug];
197                for i in 0..k_aug {
198                    atr[i] = dot(&aw[i], &r);
199                    for j in 0..k_aug {
200                        ata[i][j] = dot(&aw[i], &aw[j]);
201                    }
202                }
203                let alpha = solve_small_spd(&ata, &atr, k_aug);
204                for j in 0..k_aug {
205                    for i in 0..n {
206                        x[i] += alpha[j] * aug_orth[j][i];
207                    }
208                }
209            }
210
211            // --- Standard GMRES cycle on the (updated) residual ---
212            let ax2 = matvec(&x);
213            total_mv += 1;
214            let r2: Vec<f64> = b.iter().zip(ax2.iter()).map(|(bi, axi)| bi - axi).collect();
215            let r2_norm = norm2(&r2);
216
217            if r2_norm <= abs_tol {
218                let new_aug = extract_augmentation(&last_krylov, k_aug, n);
219                residual_history.push(r2_norm);
220                return Ok(AugmentedKrylovResult {
221                    x,
222                    residual_norm: r2_norm,
223                    iterations: total_mv,
224                    converged: true,
225                    residual_history,
226                    new_augmentation: new_aug,
227                });
228            }
229
230            // Build standard Arnoldi basis for GMRES.
231            let m = self.config.krylov_dim;
232            let mut v: Vec<Vec<f64>> = vec![vec![0.0f64; n]; m + 1];
233            let mut h: Vec<Vec<f64>> = vec![vec![0.0f64; m]; m + 1];
234
235            // v[0] = r2 / ||r2||
236            let inv_r2 = 1.0 / r2_norm;
237            for l in 0..n {
238                v[0][l] = r2[l] * inv_r2;
239            }
240
241            // Arnoldi iteration: standard GMRES (no augmentation in the basis).
242            let mut j_end = 1;
243            for j in 1..=m {
244                if j == m {
245                    j_end = m;
246                    break;
247                }
248                let w_raw = matvec(&v[j - 1]);
249                total_mv += 1;
250                let mut w = w_raw;
251
252                // Modified Gram-Schmidt orthogonalization.
253                for i in 0..j {
254                    h[i][j - 1] = dot(&w, &v[i]);
255                    for l in 0..n {
256                        w[l] -= h[i][j - 1] * v[i][l];
257                    }
258                }
259                h[j][j - 1] = norm2(&w);
260
261                if h[j][j - 1] > 1e-15 {
262                    let inv = 1.0 / h[j][j - 1];
263                    for l in 0..n {
264                        v[j][l] = w[l] * inv;
265                    }
266                    j_end = j + 1;
267                } else {
268                    j_end = j + 1;
269                    break;
270                }
271
272                if total_mv >= self.config.max_iter {
273                    j_end = j + 1;
274                    break;
275                }
276            }
277
278            let krylov_size = (j_end - 1).max(1).min(h[0].len());
279
280            // Standard GMRES RHS: g = [beta, 0, ..., 0] where beta = ||r2||.
281            let mut g = vec![0.0f64; j_end];
282            g[0] = r2_norm;
283
284            let cols = krylov_size.min(h[0].len());
285            let y = solve_least_squares_hessenberg(&h, &g, cols)?;
286
287            // Update solution: x += V * y.
288            for j in 0..y.len().min(v.len()) {
289                for i in 0..n {
290                    x[i] += y[j] * v[j][i];
291                }
292            }
293
294            // Store basis for augmentation extraction.
295            last_krylov = v[..j_end].to_vec();
296
297            if total_mv >= self.config.max_iter {
298                break;
299            }
300        }
301
302        // Final residual.
303        let ax_fin = matvec(&x);
304        total_mv += 1;
305        let r_fin: Vec<f64> = b
306            .iter()
307            .zip(ax_fin.iter())
308            .map(|(bi, axi)| bi - axi)
309            .collect();
310        let r_fin_norm = norm2(&r_fin);
311        residual_history.push(r_fin_norm);
312
313        let new_aug = extract_augmentation(&last_krylov, k_aug, n);
314
315        Ok(AugmentedKrylovResult {
316            x,
317            residual_norm: r_fin_norm,
318            iterations: total_mv,
319            converged: r_fin_norm <= abs_tol,
320            residual_history,
321            new_augmentation: new_aug,
322        })
323    }
324}
325
/// Solve a small k x k symmetric positive-definite system A x = b via
/// Cholesky factorization, falling back to a diagonal (Jacobi) approximation
/// when the factorization fails (A not numerically SPD).
pub(crate) fn solve_small_spd(a: &[Vec<f64>], b: &[f64], k: usize) -> Vec<f64> {
    const TINY: f64 = 1e-300;

    // Trivial sizes handled directly.
    if k == 0 {
        return Vec::new();
    }
    if k == 1 {
        let d = a[0][0];
        return vec![if d.abs() > TINY { b[0] / d } else { 0.0 }];
    }

    // Attempt a Cholesky factorization A = L L^T (lower triangle only).
    let mut chol = vec![vec![0.0f64; k]; k];
    let mut spd = true;
    'factor: for row in 0..k {
        for col in 0..=row {
            let mut s = a[row][col];
            for p in 0..col {
                s -= chol[row][p] * chol[col][p];
            }
            if row == col {
                if s < TINY {
                    // Non-positive pivot: A is not numerically SPD.
                    spd = false;
                    break 'factor;
                }
                chol[row][col] = s.sqrt();
            } else if chol[col][col].abs() > TINY {
                chol[row][col] = s / chol[col][col];
            } else {
                spd = false;
                break 'factor;
            }
        }
    }

    if !spd {
        // Fallback: solve with the diagonal of A only.
        return (0..k)
            .map(|i| {
                let d = a[i][i];
                if d.abs() > TINY {
                    b[i] / d
                } else {
                    0.0
                }
            })
            .collect();
    }

    // Forward substitution: L y = b.
    let mut y = vec![0.0f64; k];
    for row in 0..k {
        let mut s = b[row];
        for col in 0..row {
            s -= chol[row][col] * y[col];
        }
        y[row] = if chol[row][row].abs() > TINY {
            s / chol[row][row]
        } else {
            0.0
        };
    }

    // Back substitution: L^T x = y.
    let mut x = vec![0.0f64; k];
    for row in (0..k).rev() {
        let mut s = y[row];
        for col in (row + 1)..k {
            s -= chol[col][row] * x[col];
        }
        x[row] = if chol[row][row].abs() > TINY {
            s / chol[row][row]
        } else {
            0.0
        };
    }
    x
}
407/// Extract updated augmentation vectors from the last Krylov basis.
408/// Takes the first few vectors from the Krylov basis.
409fn extract_augmentation(krylov: &[Vec<f64>], k_aug: usize, _n: usize) -> Vec<Vec<f64>> {
410    if krylov.is_empty() || k_aug == 0 {
411        return Vec::new();
412    }
413    let m = krylov.len();
414    let take = k_aug.min(m);
415    // Return the first `take` Krylov vectors.
416    let mut new_vecs: Vec<Vec<f64>> = krylov[..take].to_vec();
417    gram_schmidt_mgs(&mut new_vecs);
418    new_vecs.retain(|vi| norm2(vi) > 0.5);
419    new_vecs
420}
421
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a matvec closure for a diagonal matrix with the given diagonal.
    fn diag_mv(diag: Vec<f64>) -> impl Fn(&[f64]) -> Vec<f64> {
        move |x: &[f64]| x.iter().zip(diag.iter()).map(|(xi, di)| xi * di).collect()
    }

    #[test]
    fn test_augmented_krylov_no_augmentation() {
        // Without augmentation, this should behave like standard GMRES.
        let n = 8;
        let diag: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let b = vec![1.0f64; n];

        let solver = AugmentedKrylov::new(AugmentedKrylovConfig {
            krylov_dim: 6,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 20,
        });

        let result = solver
            .solve(diag_mv(diag.clone()), &b, None, &[])
            .expect("augmented krylov solve failed");

        assert!(
            result.converged,
            "should converge without augmentation: residual = {:.3e}",
            result.residual_norm
        );
    }

    #[test]
    fn test_augmented_krylov_with_augmentation() {
        // Provide augmentation vectors as the first two standard basis vectors.
        let n = 10;
        let diag: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let b = vec![1.0f64; n];

        // Augmentation: e_0, e_1 (should align with first two basis vectors of solution).
        let aug = vec![
            {
                let mut v = vec![0.0f64; n];
                v[0] = 1.0;
                v
            },
            {
                let mut v = vec![0.0f64; n];
                v[1] = 1.0;
                v
            },
        ];

        let solver = AugmentedKrylov::new(AugmentedKrylovConfig {
            krylov_dim: 8,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 30,
        });

        let result = solver
            .solve(diag_mv(diag), &b, None, &aug)
            .expect("augmented krylov with augmentation failed");

        assert!(
            result.converged,
            "should converge with augmentation: residual = {:.3e}",
            result.residual_norm
        );
    }

    #[test]
    fn test_augmented_result_new_augmentation_populated() {
        // Smoke test: a single augmentation vector with default config.
        let n = 6;
        let diag: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let b = vec![1.0f64; n];

        let aug = vec![{
            let mut v = vec![0.0f64; n];
            v[0] = 1.0;
            v
        }];

        let solver = AugmentedKrylov::with_defaults();
        let result = solver
            .solve(diag_mv(diag), &b, None, &aug)
            .expect("solve failed");

        // new_augmentation may be empty if no Krylov basis was built, but should not panic.
        assert!(result.converged || result.residual_norm < 1e-8);
    }

    #[test]
    fn test_augmented_config_default() {
        // Defaults must be positive and match the documented values.
        let cfg = AugmentedKrylovConfig::default();
        assert_eq!(cfg.krylov_dim, 20);
        assert!(cfg.tol > 0.0);
        assert!(cfg.max_iter > 0);
    }
}