1use crate::error::SparseError;
29use crate::krylov::gmres_dr::{dot, gram_schmidt_mgs, norm2, solve_least_squares_hessenberg};
30
/// Tuning parameters for [`AugmentedKrylov`].
#[derive(Clone, Debug)]
pub struct AugmentedKrylovConfig {
    /// Controls the size of the Krylov basis and Hessenberg matrix that are
    /// allocated in every restart cycle.
    pub krylov_dim: usize,
    /// Convergence tolerance. The solver multiplies it by the 2-norm of the
    /// right-hand side when that norm exceeds `1e-300`, and otherwise uses it
    /// as an absolute bound on the residual norm.
    pub tol: f64,
    /// Budget on the number of operator applications counted by the solver.
    /// It is checked between steps, so the final count can slightly exceed it.
    pub max_iter: usize,
    /// Upper bound on the number of outer (restart) cycles.
    pub max_cycles: usize,
}
43
44impl Default for AugmentedKrylovConfig {
45 fn default() -> Self {
46 Self {
47 krylov_dim: 20,
48 tol: 1e-10,
49 max_iter: 1000,
50 max_cycles: 50,
51 }
52 }
53}
54
/// Outcome of a call to `AugmentedKrylov::solve`.
#[derive(Clone, Debug)]
pub struct AugmentedKrylovResult {
    /// The final iterate.
    pub x: Vec<f64>,
    /// 2-norm of `b - A x` for the returned iterate.
    pub residual_norm: f64,
    /// Number of operator applications (`matvec` calls) the solver counted.
    pub iterations: usize,
    /// `true` when `residual_norm` met the solver's tolerance.
    pub converged: bool,
    /// Residual norms recorded while the solve progressed.
    pub residual_history: Vec<f64>,
    /// Vectors derived from the Krylov basis of the last completed cycle,
    /// intended to be passed as augmentation to a later solve. Empty when no
    /// augmentation vectors were in use or no cycle completed.
    pub new_augmentation: Vec<Vec<f64>>,
}
71
72pub struct AugmentedKrylov {
87 config: AugmentedKrylovConfig,
88}
89
90impl AugmentedKrylov {
91 pub fn new(config: AugmentedKrylovConfig) -> Self {
93 Self { config }
94 }
95
96 pub fn with_defaults() -> Self {
98 Self {
99 config: AugmentedKrylovConfig::default(),
100 }
101 }
102
103 pub fn solve<F>(
119 &self,
120 matvec: F,
121 b: &[f64],
122 x0: Option<&[f64]>,
123 augmentation: &[Vec<f64>],
124 ) -> Result<AugmentedKrylovResult, SparseError>
125 where
126 F: Fn(&[f64]) -> Vec<f64>,
127 {
128 let n = b.len();
129 let mut x = match x0 {
130 Some(v) => v.to_vec(),
131 None => vec![0.0f64; n],
132 };
133
134 let b_norm = norm2(b);
135 let abs_tol = if b_norm > 1e-300 {
136 self.config.tol * b_norm
137 } else {
138 self.config.tol
139 };
140 let mut total_mv = 0usize;
141 let mut residual_history = Vec::new();
142 let mut last_krylov: Vec<Vec<f64>> = Vec::new();
143
144 let mut aug_orth: Vec<Vec<f64>> = augmentation.to_vec();
146 gram_schmidt_mgs(&mut aug_orth);
147 aug_orth.retain(|vi| norm2(vi) > 0.5);
148 let k_aug = aug_orth.len();
149
150 let mut aw: Vec<Vec<f64>> = Vec::with_capacity(k_aug);
154 for j in 0..k_aug {
155 aw.push(matvec(&aug_orth[j]));
156 total_mv += 1;
157 }
158 let mut aw_orth = aw.clone();
160 gram_schmidt_mgs(&mut aw_orth);
161 aw_orth.retain(|vi| norm2(vi) > 0.5);
162
163 for _cycle in 0..self.config.max_cycles {
164 let ax = matvec(&x);
166 total_mv += 1;
167 let r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, axi)| bi - axi).collect();
168 let r_norm = norm2(&r);
169 residual_history.push(r_norm);
170
171 if r_norm <= abs_tol {
172 let new_aug = extract_augmentation(&last_krylov, k_aug, n);
173 return Ok(AugmentedKrylovResult {
174 x,
175 residual_norm: r_norm,
176 iterations: total_mv,
177 converged: true,
178 residual_history,
179 new_augmentation: new_aug,
180 });
181 }
182
183 if total_mv >= self.config.max_iter {
184 break;
185 }
186
187 if k_aug > 0 {
192 let mut ata = vec![vec![0.0f64; k_aug]; k_aug];
196 let mut atr = vec![0.0f64; k_aug];
197 for i in 0..k_aug {
198 atr[i] = dot(&aw[i], &r);
199 for j in 0..k_aug {
200 ata[i][j] = dot(&aw[i], &aw[j]);
201 }
202 }
203 let alpha = solve_small_spd(&ata, &atr, k_aug);
204 for j in 0..k_aug {
205 for i in 0..n {
206 x[i] += alpha[j] * aug_orth[j][i];
207 }
208 }
209 }
210
211 let ax2 = matvec(&x);
213 total_mv += 1;
214 let r2: Vec<f64> = b.iter().zip(ax2.iter()).map(|(bi, axi)| bi - axi).collect();
215 let r2_norm = norm2(&r2);
216
217 if r2_norm <= abs_tol {
218 let new_aug = extract_augmentation(&last_krylov, k_aug, n);
219 residual_history.push(r2_norm);
220 return Ok(AugmentedKrylovResult {
221 x,
222 residual_norm: r2_norm,
223 iterations: total_mv,
224 converged: true,
225 residual_history,
226 new_augmentation: new_aug,
227 });
228 }
229
230 let m = self.config.krylov_dim;
232 let mut v: Vec<Vec<f64>> = vec![vec![0.0f64; n]; m + 1];
233 let mut h: Vec<Vec<f64>> = vec![vec![0.0f64; m]; m + 1];
234
235 let inv_r2 = 1.0 / r2_norm;
237 for l in 0..n {
238 v[0][l] = r2[l] * inv_r2;
239 }
240
241 let mut j_end = 1;
243 for j in 1..=m {
244 if j == m {
245 j_end = m;
246 break;
247 }
248 let w_raw = matvec(&v[j - 1]);
249 total_mv += 1;
250 let mut w = w_raw;
251
252 for i in 0..j {
254 h[i][j - 1] = dot(&w, &v[i]);
255 for l in 0..n {
256 w[l] -= h[i][j - 1] * v[i][l];
257 }
258 }
259 h[j][j - 1] = norm2(&w);
260
261 if h[j][j - 1] > 1e-15 {
262 let inv = 1.0 / h[j][j - 1];
263 for l in 0..n {
264 v[j][l] = w[l] * inv;
265 }
266 j_end = j + 1;
267 } else {
268 j_end = j + 1;
269 break;
270 }
271
272 if total_mv >= self.config.max_iter {
273 j_end = j + 1;
274 break;
275 }
276 }
277
278 let krylov_size = (j_end - 1).max(1).min(h[0].len());
279
280 let mut g = vec![0.0f64; j_end];
282 g[0] = r2_norm;
283
284 let cols = krylov_size.min(h[0].len());
285 let y = solve_least_squares_hessenberg(&h, &g, cols)?;
286
287 for j in 0..y.len().min(v.len()) {
289 for i in 0..n {
290 x[i] += y[j] * v[j][i];
291 }
292 }
293
294 last_krylov = v[..j_end].to_vec();
296
297 if total_mv >= self.config.max_iter {
298 break;
299 }
300 }
301
302 let ax_fin = matvec(&x);
304 total_mv += 1;
305 let r_fin: Vec<f64> = b
306 .iter()
307 .zip(ax_fin.iter())
308 .map(|(bi, axi)| bi - axi)
309 .collect();
310 let r_fin_norm = norm2(&r_fin);
311 residual_history.push(r_fin_norm);
312
313 let new_aug = extract_augmentation(&last_krylov, k_aug, n);
314
315 Ok(AugmentedKrylovResult {
316 x,
317 residual_norm: r_fin_norm,
318 iterations: total_mv,
319 converged: r_fin_norm <= abs_tol,
320 residual_history,
321 new_augmentation: new_aug,
322 })
323 }
324}
325
/// Solves the dense `k`×`k` system `a · x = b`, where `a` is expected to be
/// symmetric positive definite.
///
/// A Cholesky factorisation `a = L·Lᵀ` is attempted first. If a pivot is not
/// safely positive, the routine falls back to dividing each right-hand-side
/// entry by the matching diagonal entry of `a` (entries with a vanishing
/// diagonal become `0.0`).
///
/// Returns an empty vector for `k == 0`; for `k == 1` the result is
/// `b[0] / a[0][0]`, or `0.0` when that diagonal entry vanishes.
pub(crate) fn solve_small_spd(a: &[Vec<f64>], b: &[f64], k: usize) -> Vec<f64> {
    // Magnitudes at or below this value are treated as zero.
    const TINY: f64 = 1e-300;
    let usable = |pivot: f64| pivot.abs() > TINY;

    match k {
        0 => return Vec::new(),
        1 => {
            let only = a[0][0];
            let value = if usable(only) { b[0] / only } else { 0.0 };
            return vec![value];
        }
        _ => {}
    }

    // Lower-triangular Cholesky factor, or `None` when the matrix turns out
    // not to be numerically positive definite.
    let factor = || -> Option<Vec<Vec<f64>>> {
        let mut lower = vec![vec![0.0f64; k]; k];
        for row in 0..k {
            for col in 0..=row {
                let reduced =
                    (0..col).fold(a[row][col], |acc, p| acc - lower[row][p] * lower[col][p]);
                if row == col {
                    if reduced < TINY {
                        return None;
                    }
                    lower[row][col] = reduced.sqrt();
                } else if usable(lower[col][col]) {
                    lower[row][col] = reduced / lower[col][col];
                } else {
                    return None;
                }
            }
        }
        Some(lower)
    };

    match factor() {
        Some(lower) => {
            // Forward substitution: L · fwd = b.
            let mut fwd = vec![0.0f64; k];
            for row in 0..k {
                let rhs = (0..row).fold(b[row], |acc, col| acc - lower[row][col] * fwd[col]);
                fwd[row] = if usable(lower[row][row]) {
                    rhs / lower[row][row]
                } else {
                    0.0
                };
            }

            // Back substitution: Lᵀ · sol = fwd.
            let mut sol = vec![0.0f64; k];
            for row in (0..k).rev() {
                let rhs =
                    ((row + 1)..k).fold(fwd[row], |acc, col| acc - lower[col][row] * sol[col]);
                sol[row] = if usable(lower[row][row]) {
                    rhs / lower[row][row]
                } else {
                    0.0
                };
            }
            sol
        }
        // Diagonal-only approximation when the factorisation failed.
        None => (0..k)
            .map(|i| if usable(a[i][i]) { b[i] / a[i][i] } else { 0.0 })
            .collect(),
    }
}
406
407fn extract_augmentation(krylov: &[Vec<f64>], k_aug: usize, _n: usize) -> Vec<Vec<f64>> {
410 if krylov.is_empty() || k_aug == 0 {
411 return Vec::new();
412 }
413 let m = krylov.len();
414 let take = k_aug.min(m);
415 let mut new_vecs: Vec<Vec<f64>> = krylov[..take].to_vec();
417 gram_schmidt_mgs(&mut new_vecs);
418 new_vecs.retain(|vi| norm2(vi) > 0.5);
419 new_vecs
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a closure applying the diagonal operator with the given entries.
    fn diag_mv(diag: Vec<f64>) -> impl Fn(&[f64]) -> Vec<f64> {
        move |x: &[f64]| {
            let mut out = Vec::with_capacity(x.len().min(diag.len()));
            for (xi, di) in x.iter().zip(diag.iter()) {
                out.push(xi * di);
            }
            out
        }
    }

    /// Diagonal entries `1, 2, ..., n`.
    fn ramp(n: usize) -> Vec<f64> {
        (1..=n).map(|i| i as f64).collect()
    }

    /// Vector of length `n` with a single `1.0` at position `hot`.
    fn unit_vector(n: usize, hot: usize) -> Vec<f64> {
        let mut e = vec![0.0f64; n];
        e[hot] = 1.0;
        e
    }

    /// Plain restarted solve on a small diagonal system must converge.
    #[test]
    fn test_augmented_krylov_no_augmentation() {
        let n = 8;
        let rhs = vec![1.0f64; n];
        let config = AugmentedKrylovConfig {
            krylov_dim: 6,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 20,
        };

        let result = AugmentedKrylov::new(config)
            .solve(diag_mv(ramp(n)), &rhs, None, &[])
            .expect("augmented krylov solve failed");

        assert!(
            result.converged,
            "should converge without augmentation: residual = {:.3e}",
            result.residual_norm
        );
    }

    /// Supplying the two leading coordinate vectors as augmentation must
    /// still lead to convergence.
    #[test]
    fn test_augmented_krylov_with_augmentation() {
        let n = 10;
        let rhs = vec![1.0f64; n];
        let aug: Vec<Vec<f64>> = (0..2).map(|k| unit_vector(n, k)).collect();
        let config = AugmentedKrylovConfig {
            krylov_dim: 8,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 30,
        };

        let result = AugmentedKrylov::new(config)
            .solve(diag_mv(ramp(n)), &rhs, None, &aug)
            .expect("augmented krylov with augmentation failed");

        assert!(
            result.converged,
            "should converge with augmentation: residual = {:.3e}",
            result.residual_norm
        );
    }

    /// Default configuration with a single augmentation vector reaches a
    /// small residual.
    #[test]
    fn test_augmented_result_new_augmentation_populated() {
        let n = 6;
        let rhs = vec![1.0f64; n];
        let aug = vec![unit_vector(n, 0)];

        let result = AugmentedKrylov::with_defaults()
            .solve(diag_mv(ramp(n)), &rhs, None, &aug)
            .expect("solve failed");

        assert!(result.converged || result.residual_norm < 1e-8);
    }

    /// The default configuration exposes sane, positive settings.
    #[test]
    fn test_augmented_config_default() {
        let defaults = AugmentedKrylovConfig::default();
        assert_eq!(defaults.krylov_dim, 20);
        assert!(defaults.tol > 0.0);
        assert!(defaults.max_iter > 0);
    }
}