1use scivex_core::Float;
8use scivex_core::linalg::CsrMatrix;
9use scivex_core::tensor::Tensor;
10
11use crate::error::{OptimError, Result};
12
13#[cfg_attr(
19 feature = "serde-support",
20 derive(serde::Serialize, serde::Deserialize)
21)]
22#[derive(Debug, Clone)]
23pub struct SparseSolveResult<T: Float> {
24 pub x: Vec<T>,
26 pub iterations: usize,
28 pub residual_norm: T,
30 pub converged: bool,
32}
33
34#[inline]
40fn vec_dot<T: Float>(a: &[T], b: &[T]) -> T {
41 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
42}
43
44#[inline]
46fn vec_norm<T: Float>(v: &[T]) -> T {
47 vec_dot(v, v).sqrt()
48}
49
50#[inline]
52fn vec_axpy<T: Float>(alpha: T, x: &[T], y: &mut [T]) {
53 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
54 *yi += alpha * xi;
55 }
56}
57
58fn compute_residual<T: Float>(a: &CsrMatrix<T>, x: &[T], b: &[T]) -> Result<Vec<T>> {
60 let x_tensor = Tensor::from_vec(x.to_vec(), vec![x.len()])?;
61 let ax = a.matvec(&x_tensor)?;
62 let ax_slice = ax.as_slice();
63 let r: Vec<T> = b
64 .iter()
65 .zip(ax_slice.iter())
66 .map(|(&bi, &ai)| bi - ai)
67 .collect();
68 Ok(r)
69}
70
71fn sparse_matvec<T: Float>(a: &CsrMatrix<T>, x: &[T]) -> Result<Vec<T>> {
73 let x_tensor = Tensor::from_vec(x.to_vec(), vec![x.len()])?;
74 let result = a.matvec(&x_tensor)?;
75 Ok(result.as_slice().to_vec())
76}
77
78pub fn conjugate_gradient<T: Float>(
103 a: &CsrMatrix<T>,
104 b: &[T],
105 x0: Option<&[T]>,
106 max_iter: usize,
107 tol: T,
108) -> Result<SparseSolveResult<T>> {
109 let n = a.nrows();
110 if a.ncols() != n {
111 return Err(OptimError::InvalidParameter {
112 name: "a",
113 reason: "matrix must be square",
114 });
115 }
116 if b.len() != n {
117 return Err(OptimError::InvalidParameter {
118 name: "b",
119 reason: "length must match matrix dimension",
120 });
121 }
122 if x0.is_some_and(|x0v| x0v.len() != n) {
123 return Err(OptimError::InvalidParameter {
124 name: "x0",
125 reason: "length must match matrix dimension",
126 });
127 }
128
129 let b_norm = vec_norm(b);
130
131 let mut x = x0.map_or_else(|| vec![T::zero(); n], <[T]>::to_vec);
133
134 let mut r = compute_residual(a, &x, b)?;
136 let mut p = r.clone();
138 let mut rr = vec_dot(&r, &r);
140
141 let threshold = tol * b_norm;
142
143 for k in 0..max_iter {
144 let r_norm = rr.sqrt();
145 if r_norm < threshold || (b_norm == T::zero() && r_norm < tol) {
146 return Ok(SparseSolveResult {
147 x,
148 iterations: k,
149 residual_norm: r_norm,
150 converged: true,
151 });
152 }
153
154 let ap = sparse_matvec(a, &p)?;
156 let p_ap = vec_dot(&p, &ap);
157
158 if p_ap == T::zero() {
159 return Ok(SparseSolveResult {
161 x,
162 iterations: k,
163 residual_norm: r_norm,
164 converged: false,
165 });
166 }
167
168 let alpha = rr / p_ap;
169
170 vec_axpy(alpha, &p, &mut x);
172
173 vec_axpy(-alpha, &ap, &mut r);
175
176 let rr_new = vec_dot(&r, &r);
177 let beta = rr_new / rr;
178
179 for (pi, &ri) in p.iter_mut().zip(r.iter()) {
181 *pi = ri + beta * *pi;
182 }
183
184 rr = rr_new;
185 }
186
187 let final_norm = rr.sqrt();
188 Ok(SparseSolveResult {
189 x,
190 iterations: max_iter,
191 residual_norm: final_norm,
192 converged: final_norm < threshold,
193 })
194}
195
196#[allow(clippy::too_many_lines)]
213pub fn bicgstab<T: Float>(
214 a: &CsrMatrix<T>,
215 b: &[T],
216 x0: Option<&[T]>,
217 max_iter: usize,
218 tol: T,
219) -> Result<SparseSolveResult<T>> {
220 let n = a.nrows();
221 if a.ncols() != n {
222 return Err(OptimError::InvalidParameter {
223 name: "a",
224 reason: "matrix must be square",
225 });
226 }
227 if b.len() != n {
228 return Err(OptimError::InvalidParameter {
229 name: "b",
230 reason: "length must match matrix dimension",
231 });
232 }
233 if x0.is_some_and(|x0v| x0v.len() != n) {
234 return Err(OptimError::InvalidParameter {
235 name: "x0",
236 reason: "length must match matrix dimension",
237 });
238 }
239
240 let b_norm = vec_norm(b);
241 let threshold = tol * b_norm;
242
243 let mut x = match x0 {
245 Some(v) => v.to_vec(),
246 None => vec![T::zero(); n],
247 };
248
249 let mut r = compute_residual(a, &x, b)?;
251 let r_hat = r.clone();
253
254 let mut rho = T::one();
255 let mut alpha = T::one();
256 let mut omega = T::one();
257
258 let mut v = vec![T::zero(); n];
259 let mut p = vec![T::zero(); n];
260
261 for k in 0..max_iter {
262 let r_norm = vec_norm(&r);
263 if r_norm < threshold || (b_norm == T::zero() && r_norm < tol) {
264 return Ok(SparseSolveResult {
265 x,
266 iterations: k,
267 residual_norm: r_norm,
268 converged: true,
269 });
270 }
271
272 let rho_new = vec_dot(&r_hat, &r);
273
274 if rho_new == T::zero() {
275 return Ok(SparseSolveResult {
277 x,
278 iterations: k,
279 residual_norm: r_norm,
280 converged: false,
281 });
282 }
283
284 let beta = (rho_new / rho) * (alpha / omega);
285
286 for ((pi, &ri), &vi) in p.iter_mut().zip(r.iter()).zip(v.iter()) {
288 *pi = ri + beta * (*pi - omega * vi);
289 }
290
291 v = sparse_matvec(a, &p)?;
293
294 let r_hat_v = vec_dot(&r_hat, &v);
295 if r_hat_v == T::zero() {
296 return Ok(SparseSolveResult {
297 x,
298 iterations: k,
299 residual_norm: r_norm,
300 converged: false,
301 });
302 }
303 alpha = rho_new / r_hat_v;
304
305 let mut s = r.clone();
307 vec_axpy(-alpha, &v, &mut s);
308
309 let s_norm = vec_norm(&s);
310 if s_norm < threshold {
311 vec_axpy(alpha, &p, &mut x);
313 return Ok(SparseSolveResult {
314 x,
315 iterations: k + 1,
316 residual_norm: s_norm,
317 converged: true,
318 });
319 }
320
321 let t = sparse_matvec(a, &s)?;
323
324 let t_t = vec_dot(&t, &t);
325 omega = if t_t == T::zero() {
326 T::zero()
327 } else {
328 vec_dot(&t, &s) / t_t
329 };
330
331 vec_axpy(alpha, &p, &mut x);
333 vec_axpy(omega, &s, &mut x);
334
335 r = s;
337 vec_axpy(-omega, &t, &mut r);
338
339 rho = rho_new;
340
341 if omega == T::zero() {
342 let final_norm = vec_norm(&r);
344 return Ok(SparseSolveResult {
345 x,
346 iterations: k + 1,
347 residual_norm: final_norm,
348 converged: false,
349 });
350 }
351 }
352
353 let final_norm = vec_norm(&r);
354 Ok(SparseSolveResult {
355 x,
356 iterations: max_iter,
357 residual_norm: final_norm,
358 converged: final_norm < threshold,
359 })
360}
361
362#[cfg_attr(
372 feature = "serde-support",
373 derive(serde::Serialize, serde::Deserialize)
374)]
375#[derive(Debug, Clone)]
376pub struct JacobiPreconditioner<T: Float> {
377 inv_diag: Vec<T>,
378}
379
380impl<T: Float> JacobiPreconditioner<T> {
381 pub fn new(a: &CsrMatrix<T>) -> Self {
385 let n = a.nrows().min(a.ncols());
386 let mut inv_diag = Vec::with_capacity(n);
387 for i in 0..n {
388 let d = a.get(i, i).copied().unwrap_or(T::zero());
389 if d == T::zero() {
390 inv_diag.push(T::one());
391 } else {
392 inv_diag.push(T::one() / d);
393 }
394 }
395 Self { inv_diag }
396 }
397
398 pub fn apply(&self, r: &[T]) -> Vec<T> {
401 r.iter()
402 .zip(self.inv_diag.iter())
403 .map(|(&ri, &di)| ri * di)
404 .collect()
405 }
406}
407
408pub fn preconditioned_cg<T: Float>(
418 a: &CsrMatrix<T>,
419 b: &[T],
420 preconditioner: &JacobiPreconditioner<T>,
421 x0: Option<&[T]>,
422 max_iter: usize,
423 tol: T,
424) -> Result<SparseSolveResult<T>> {
425 let n = a.nrows();
426 if a.ncols() != n {
427 return Err(OptimError::InvalidParameter {
428 name: "a",
429 reason: "matrix must be square",
430 });
431 }
432 if b.len() != n {
433 return Err(OptimError::InvalidParameter {
434 name: "b",
435 reason: "length must match matrix dimension",
436 });
437 }
438 if x0.is_some_and(|x0v| x0v.len() != n) {
439 return Err(OptimError::InvalidParameter {
440 name: "x0",
441 reason: "length must match matrix dimension",
442 });
443 }
444
445 let b_norm = vec_norm(b);
446 let threshold = tol * b_norm;
447
448 let mut x = match x0 {
449 Some(v) => v.to_vec(),
450 None => vec![T::zero(); n],
451 };
452
453 let mut r = compute_residual(a, &x, b)?;
455 let mut z = preconditioner.apply(&r);
457 let mut p = z.clone();
459 let mut rz = vec_dot(&r, &z);
461
462 for k in 0..max_iter {
463 let r_norm = vec_norm(&r);
464 if r_norm < threshold || (b_norm == T::zero() && r_norm < tol) {
465 return Ok(SparseSolveResult {
466 x,
467 iterations: k,
468 residual_norm: r_norm,
469 converged: true,
470 });
471 }
472
473 let ap = sparse_matvec(a, &p)?;
475 let p_ap = vec_dot(&p, &ap);
476
477 if p_ap == T::zero() {
478 return Ok(SparseSolveResult {
479 x,
480 iterations: k,
481 residual_norm: r_norm,
482 converged: false,
483 });
484 }
485
486 let alpha = rz / p_ap;
487
488 vec_axpy(alpha, &p, &mut x);
490
491 vec_axpy(-alpha, &ap, &mut r);
493
494 z = preconditioner.apply(&r);
496
497 let rz_new = vec_dot(&r, &z);
498 let beta = rz_new / rz;
499
500 for (pi, &zi) in p.iter_mut().zip(z.iter()) {
502 *pi = zi + beta * *pi;
503 }
504
505 rz = rz_new;
506 }
507
508 let final_norm = vec_norm(&r);
509 Ok(SparseSolveResult {
510 x,
511 iterations: max_iter,
512 residual_norm: final_norm,
513 converged: final_norm < threshold,
514 })
515}
516
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use scivex_core::linalg::CsrMatrix;

    /// 3x3 fixture `[4 1 0; 1 3 1; 0 1 4]`; `A * [1, 2, 3] = [6, 10, 14]`.
    fn spd_3x3() -> CsrMatrix<f64> {
        let row_idx = vec![0, 0, 1, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 2, 1, 2];
        let entries = vec![4.0, 1.0, 1.0, 3.0, 1.0, 1.0, 4.0];
        CsrMatrix::from_triplets(3, 3, row_idx, col_idx, entries).unwrap()
    }

    /// Asserts that `x[k]` is within `tol` of `expected[k]` for every `k`.
    fn assert_components_close(x: &[f64], expected: &[f64], tol: f64) {
        for (k, want) in expected.iter().enumerate() {
            assert!((x[k] - *want).abs() < tol);
        }
    }

    #[test]
    fn test_cg_simple_3x3() {
        let a = spd_3x3();
        let b = [6.0, 10.0, 14.0];

        let result = conjugate_gradient(&a, &b, None, 100, 1e-10).unwrap();

        assert!(result.converged);
        assert_components_close(&result.x, &[1.0, 2.0, 3.0], 1e-8);
    }

    #[test]
    fn test_cg_diagonal_system() {
        // diag(1..=n) x = b with the known solution x_i = i + 1.
        let n = 10;
        let expected: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let diag_idx: Vec<usize> = (0..n).collect();
        let vals: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let b: Vec<f64> = (0..n).map(|i| (i + 1) as f64 * expected[i]).collect();

        let a = CsrMatrix::from_triplets(n, n, diag_idx.clone(), diag_idx, vals).unwrap();
        let result = conjugate_gradient(&a, &b, None, 100, 1e-12).unwrap();

        assert!(result.converged);
        for (i, (xi, ei)) in result.x.iter().zip(expected.iter()).enumerate() {
            assert!((*xi - *ei).abs() < 1e-8, "x[{i}] = {xi}, expected {ei}",);
        }
    }

    #[test]
    fn test_bicgstab_nonsymmetric() {
        // [3 1 0; 0 4 2; 1 0 5] * [1, 2, 3] = [5, 14, 16].
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 1, 2, 0, 2];
        let entries = vec![3.0, 1.0, 4.0, 2.0, 1.0, 5.0];
        let a = CsrMatrix::from_triplets(3, 3, row_idx, col_idx, entries).unwrap();
        let b = [5.0, 14.0, 16.0];

        let result = bicgstab(&a, &b, None, 100, 1e-10).unwrap();

        assert!(result.converged, "BiCGSTAB did not converge");
        assert_components_close(&result.x, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_bicgstab_on_spd() {
        let a = spd_3x3();
        let b = [6.0, 10.0, 14.0];

        let result = bicgstab(&a, &b, None, 100, 1e-10).unwrap();

        assert!(result.converged);
        assert_components_close(&result.x, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_jacobi_preconditioner() {
        let a = spd_3x3();
        let prec = JacobiPreconditioner::new(&a);

        // Stored factors are the reciprocals of the diagonal 4, 3, 4.
        let want_factors: [f64; 3] = [0.25, 1.0 / 3.0, 0.25];
        assert_components_close(&prec.inv_diag, &want_factors, 1e-15);

        let r = [4.0, 3.0, 8.0];
        let z = prec.apply(&r);
        assert_components_close(&z, &[1.0, 1.0, 2.0], 1e-15);
    }

    #[test]
    fn test_preconditioned_cg_converges() {
        // Diagonal system with entries (i + 1)^2 and b equal to the diagonal.
        let n = 20;
        let diag_idx: Vec<usize> = (0..n).collect();
        let vals: Vec<f64> = (0..n).map(|i| ((i + 1) * (i + 1)) as f64).collect();
        let b = vals.clone();

        let a = CsrMatrix::from_triplets(n, n, diag_idx.clone(), diag_idx, vals).unwrap();
        let prec = JacobiPreconditioner::new(&a);

        let result_pcg = preconditioned_cg(&a, &b, &prec, None, 100, 1e-10).unwrap();
        let result_cg = conjugate_gradient(&a, &b, None, 100, 1e-10).unwrap();

        assert!(result_pcg.converged);
        assert!(result_cg.converged);
        assert!(
            result_pcg.iterations <= result_cg.iterations,
            "PCG iters {} > CG iters {}",
            result_pcg.iterations,
            result_cg.iterations
        );
    }

    #[test]
    fn test_cg_custom_initial_guess() {
        let a = spd_3x3();
        let b = [6.0, 10.0, 14.0];
        let x0 = [0.9, 2.1, 2.9];

        let result = conjugate_gradient(&a, &b, Some(&x0), 100, 1e-10).unwrap();

        assert!(result.converged);
        assert_components_close(&result.x, &[1.0, 2.0, 3.0], 1e-8);
    }

    #[test]
    fn test_convergence_failure_max_iter_too_small() {
        let a = spd_3x3();
        let b = [6.0, 10.0, 14.0];

        // With no iterations allowed the zero initial guess is returned as-is.
        let result = conjugate_gradient(&a, &b, None, 0, 1e-10).unwrap();

        assert!(!result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let a = spd_3x3();
        let b = [6.0, 10.0, 14.0];
        let b_short = [1.0, 2.0];
        let x0_bad = [1.0, 2.0];

        // Right-hand side of the wrong length.
        assert!(conjugate_gradient(&a, &b_short, None, 10, 1e-10).is_err());

        // Initial guess of the wrong length.
        assert!(conjugate_gradient(&a, &b, Some(&x0_bad), 10, 1e-10).is_err());

        // Non-square matrix.
        let rect = CsrMatrix::from_triplets(2, 3, vec![0, 1], vec![0, 1], vec![1.0, 2.0]).unwrap();
        assert!(conjugate_gradient(&rect, &[1.0, 2.0], None, 10, 1e-10).is_err());

        // BiCGSTAB performs the same right-hand-side check.
        assert!(bicgstab(&a, &b_short, None, 10, 1e-10).is_err());
    }

    #[test]
    fn test_bicgstab_tridiagonal() {
        // 4x4 matrix with 2 on the diagonal and -1 on the superdiagonal;
        // the all-ones vector maps to [1, 1, 1, 2].
        let row_idx = vec![0, 0, 1, 1, 2, 2, 3];
        let col_idx = vec![0, 1, 1, 2, 2, 3, 3];
        let entries = vec![2.0, -1.0, 2.0, -1.0, 2.0, -1.0, 2.0];
        let a = CsrMatrix::from_triplets(4, 4, row_idx, col_idx, entries).unwrap();
        let b = [1.0, 1.0, 1.0, 2.0];

        let result = bicgstab(&a, &b, None, 100, 1e-10).unwrap();

        assert!(
            result.converged,
            "BiCGSTAB did not converge on tridiagonal system"
        );
        for i in 0..4 {
            assert!(
                (result.x[i] - 1.0).abs() < 1e-6,
                "x[{i}] = {}, expected 1.0",
                result.x[i]
            );
        }
    }
}