1use crate::error::{SparseError, SparseResult};
9use crate::gpu::construction::GpuCsrMatrix;
10use crate::gpu::spmv::GpuSpMvBackend;
11
/// Backend used by the iterative solvers for their sparse matrix–vector
/// products.
///
/// Marked `#[non_exhaustive]` so new backends can be added without breaking
/// downstream `match` expressions.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GpuSolverBackend {
    /// Portable CPU execution path (the default).
    #[default]
    Cpu,
    /// WebGPU compute execution path.
    WebGpu,
}
29
30impl From<GpuSolverBackend> for GpuSpMvBackend {
31 fn from(b: GpuSolverBackend) -> Self {
32 match b {
33 GpuSolverBackend::Cpu => GpuSpMvBackend::Cpu,
34 GpuSolverBackend::WebGpu => GpuSpMvBackend::WebGpu,
35 }
36 }
37}
38
/// Tunable parameters shared by the iterative solvers.
#[derive(Debug, Clone)]
pub struct GpuSolverConfig {
    /// Maximum number of iterations before the solver gives up.
    pub max_iter: usize,
    /// Convergence tolerance, relative to `||b||` (absolute when `b` is the
    /// zero vector).
    pub tol: f64,
    /// Whether to apply a Jacobi (diagonal) preconditioner.
    pub precond: bool,
    /// Backend used for sparse matrix–vector products.
    pub backend: GpuSolverBackend,
}
51
52impl Default for GpuSolverConfig {
53 fn default() -> Self {
54 Self {
55 max_iter: 1000,
56 tol: 1e-8,
57 precond: true,
58 backend: GpuSolverBackend::Cpu,
59 }
60 }
61}
62
/// Outcome of an iterative solve.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Computed solution vector.
    pub x: Vec<f64>,
    /// True residual norm `||b - A x||` at exit (recomputed, not the
    /// recurrence residual).
    pub residual_norm: f64,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the tolerance was reached within `max_iter` iterations.
    pub converged: bool,
}
75
/// Euclidean inner product of two equal-length slices.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
84
/// Euclidean (L2) norm of a slice.
#[inline]
fn norm2(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}
89
/// In-place scaled addition: `y[i] += alpha * x[i]` for each element.
#[inline]
fn axpy(alpha: f64, x: &[f64], y: &mut [f64]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(dst, src)| *dst += alpha * src);
}
96
/// Elementwise linear combination: `z[i] = alpha * x[i] + beta * y[i]`.
#[inline]
fn axpby(alpha: f64, x: &[f64], beta: f64, y: &[f64], z: &mut [f64]) {
    z.iter_mut()
        .zip(x.iter().zip(y.iter()))
        .for_each(|(out, (&xi, &yi))| *out = alpha * xi + beta * yi);
}
104
105fn jacobi_diag(matrix: &GpuCsrMatrix) -> Vec<f64> {
111 let n = matrix.n_rows;
112 let mut diag = vec![1.0_f64; n];
113 for row in 0..n {
114 let start = matrix.row_ptr[row];
115 let end = matrix.row_ptr[row + 1];
116 for k in start..end {
117 if matrix.col_idx[k] == row {
118 let d = matrix.values[k];
119 if d.abs() > f64::EPSILON {
120 diag[row] = d;
121 }
122 }
123 }
124 }
125 diag
126}
127
/// Apply the Jacobi preconditioner: `z[i] = r[i] / diag[i]` elementwise.
fn apply_jacobi(diag: &[f64], r: &[f64], z: &mut [f64]) {
    z.iter_mut()
        .zip(r.iter().zip(diag.iter()))
        .for_each(|(out, (&num, &den))| *out = num / den);
}
134
135pub fn cg_csr(
150 matrix: &GpuCsrMatrix,
151 b: &[f64],
152 x0: Option<&[f64]>,
153 config: &GpuSolverConfig,
154) -> SparseResult<SolverResult> {
155 let n = matrix.n_rows;
156 if matrix.n_cols != n {
157 return Err(SparseError::ComputationError(
158 "CG requires a square matrix".to_string(),
159 ));
160 }
161 if b.len() != n {
162 return Err(SparseError::DimensionMismatch {
163 expected: n,
164 found: b.len(),
165 });
166 }
167 if let Some(x) = x0 {
168 if x.len() != n {
169 return Err(SparseError::DimensionMismatch {
170 expected: n,
171 found: x.len(),
172 });
173 }
174 }
175
176 let diag = if config.precond {
177 jacobi_diag(matrix)
178 } else {
179 vec![1.0; n]
180 };
181
182 let mut x = match x0 {
184 Some(x0) => x0.to_vec(),
185 None => vec![0.0; n],
186 };
187
188 let ax = matrix.spmv(&x)?;
190 let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, &ai)| bi - ai).collect();
191
192 let mut z = vec![0.0; n];
194 apply_jacobi(&diag, &r, &mut z);
195
196 let mut p = z.clone();
198
199 let mut rz = dot(&r, &z);
200 let b_norm = norm2(b);
201 let tol_abs = if b_norm > 0.0 {
202 config.tol * b_norm
203 } else {
204 config.tol
205 };
206
207 let mut iter = 0usize;
208 let mut converged = false;
209
210 while iter < config.max_iter {
211 let ap = matrix.spmv(&p)?;
212 let pap = dot(&p, &ap);
213 if pap.abs() < f64::MIN_POSITIVE {
214 break; }
216 let alpha = rz / pap;
217
218 axpy(alpha, &p, &mut x);
220
221 axpy(-alpha, &ap, &mut r);
223
224 let r_norm = norm2(&r);
225 iter += 1;
226
227 if r_norm <= tol_abs {
228 converged = true;
229 break;
230 }
231
232 apply_jacobi(&diag, &r, &mut z);
234
235 let rz_new = dot(&r, &z);
236 let beta = rz_new / rz;
237 rz = rz_new;
238
239 let p_old = p.clone();
241 axpby(1.0, &z, beta, &p_old, &mut p);
242 }
243
244 let residual = matrix.spmv(&x)?;
245 let res_norm = norm2(
246 &b.iter()
247 .zip(residual.iter())
248 .map(|(bi, &ri)| bi - ri)
249 .collect::<Vec<_>>(),
250 );
251
252 Ok(SolverResult {
253 x,
254 residual_norm: res_norm,
255 n_iter: iter,
256 converged,
257 })
258}
259
260pub fn bicgstab_csr(
275 matrix: &GpuCsrMatrix,
276 b: &[f64],
277 x0: Option<&[f64]>,
278 config: &GpuSolverConfig,
279) -> SparseResult<SolverResult> {
280 let n = matrix.n_rows;
281 if matrix.n_cols != n {
282 return Err(SparseError::ComputationError(
283 "BiCGSTAB requires a square matrix".to_string(),
284 ));
285 }
286 if b.len() != n {
287 return Err(SparseError::DimensionMismatch {
288 expected: n,
289 found: b.len(),
290 });
291 }
292
293 let diag = if config.precond {
294 jacobi_diag(matrix)
295 } else {
296 vec![1.0; n]
297 };
298
299 let mut x = match x0 {
300 Some(x0) => x0.to_vec(),
301 None => vec![0.0; n],
302 };
303
304 let ax0 = matrix.spmv(&x)?;
306 let mut r: Vec<f64> = b.iter().zip(ax0.iter()).map(|(bi, &ai)| bi - ai).collect();
307
308 let r_hat = r.clone();
310
311 let b_norm = norm2(b);
312 let tol_abs = if b_norm > 0.0 {
313 config.tol * b_norm
314 } else {
315 config.tol
316 };
317
318 let mut p = r.clone();
320
321 let mut rho = dot(&r_hat, &r);
322 #[allow(unused_assignments)]
325 let mut omega = 1.0_f64;
326
327 let mut p_hat = vec![0.0; n];
329 let mut s_hat = vec![0.0; n];
330
331 let mut iter = 0usize;
332 let mut converged = false;
333
334 while iter < config.max_iter {
335 apply_jacobi(&diag, &p, &mut p_hat);
337
338 let v = matrix.spmv(&p_hat)?;
339 let rtv = dot(&r_hat, &v);
340 if rtv.abs() < f64::MIN_POSITIVE {
341 break; }
343 let alpha = rho / rtv;
344
345 let mut s: Vec<f64> = r
347 .iter()
348 .zip(v.iter())
349 .map(|(&ri, &vi)| ri - alpha * vi)
350 .collect();
351
352 let s_norm = norm2(&s);
353 if s_norm <= tol_abs {
354 axpy(alpha, &p_hat, &mut x);
355 iter += 1;
356 converged = true;
357 break;
358 }
359
360 apply_jacobi(&diag, &s, &mut s_hat);
362
363 let t = matrix.spmv(&s_hat)?;
364 let tt = dot(&t, &t);
365 omega = if tt > f64::MIN_POSITIVE {
366 dot(&t, &s) / tt
367 } else {
368 break;
369 };
370
371 axpy(alpha, &p_hat, &mut x);
373 axpy(omega, &s_hat, &mut x);
374
375 for ((ri, &si), &ti) in r.iter_mut().zip(s.iter()).zip(t.iter()) {
377 *ri = si - omega * ti;
378 }
379
380 let r_norm = norm2(&r);
381 iter += 1;
382 if r_norm <= tol_abs {
383 converged = true;
384 break;
385 }
386
387 let rho_new = dot(&r_hat, &r);
389 if rho_new.abs() < f64::MIN_POSITIVE {
390 break;
391 }
392 let beta = (rho_new / rho) * (alpha / omega);
393 rho = rho_new;
394
395 for ((pi, &ri), &vi) in p.iter_mut().zip(r.iter()).zip(v.iter()) {
397 *pi = ri + beta * (*pi - omega * vi);
398 }
399 }
400
401 let residual = matrix.spmv(&x)?;
402 let res_norm = norm2(
403 &b.iter()
404 .zip(residual.iter())
405 .map(|(bi, &ri)| bi - ri)
406 .collect::<Vec<_>>(),
407 );
408
409 Ok(SolverResult {
410 x,
411 residual_norm: res_norm,
412 n_iter: iter,
413 converged,
414 })
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::construction::{GpuCooMatrix, GpuCsrMatrix};

    /// Build an n×n symmetric positive definite tridiagonal matrix
    /// (4 on the diagonal, -1 on the off-diagonals). Diagonally dominant,
    /// so CG is guaranteed to converge on it.
    fn tridiag_spd(n: usize) -> GpuCsrMatrix {
        let mut coo = GpuCooMatrix::new(n, n);
        for i in 0..n {
            coo.push(i, i, 4.0);
            if i > 0 {
                coo.push(i, i - 1, -1.0);
                coo.push(i - 1, i, -1.0);
            }
        }
        coo.to_csr()
    }

    /// CG on an SPD system recovers a known solution: build b = A x_true,
    /// solve, and compare elementwise.
    #[test]
    fn test_cg_spd_system() {
        let n = 5;
        let mat = tridiag_spd(n);
        let x_true = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = mat.spmv(&x_true).expect("spmv failed");

        let config = GpuSolverConfig::default();
        let result = cg_csr(&mat, &b, None, &config).expect("CG failed");
        assert!(result.converged, "CG did not converge");
        assert!(result.residual_norm < 1e-6);
        for (xi, &xt) in result.x.iter().zip(x_true.iter()) {
            assert!((xi - xt).abs() < 1e-6, "x[i]={xi} expected {xt}");
        }
    }

    /// BiCGSTAB on a non-symmetric (but diagonally dominant) system.
    #[test]
    fn test_bicgstab_general() {
        let mut coo = GpuCooMatrix::new(4, 4);
        coo.push(0, 0, 4.0);
        coo.push(0, 1, 1.0);
        coo.push(1, 0, 2.0);
        coo.push(1, 1, 5.0);
        coo.push(1, 2, 1.0);
        coo.push(2, 1, 2.0);
        coo.push(2, 2, 6.0);
        coo.push(2, 3, 1.0);
        coo.push(3, 2, 2.0);
        coo.push(3, 3, 7.0);
        let mat = coo.to_csr();

        let x_true = vec![1.0, 2.0, 3.0, 4.0];
        let b = mat.spmv(&x_true).expect("spmv failed");

        let config = GpuSolverConfig::default();
        let result = bicgstab_csr(&mat, &b, None, &config).expect("BiCGSTAB failed");
        assert!(result.converged, "BiCGSTAB did not converge");
        assert!(result.residual_norm < 1e-6);
    }

    /// Jacobi preconditioning should not be (much) worse than plain CG.
    /// The +5 slack accounts for the well-conditioned test matrix, where
    /// preconditioning brings little benefit.
    #[test]
    fn test_cg_with_precond() {
        let n = 10;
        let mat = tridiag_spd(n);
        let b = vec![1.0; n];

        let config_precond = GpuSolverConfig {
            precond: true,
            ..Default::default()
        };
        let config_nopc = GpuSolverConfig {
            precond: false,
            ..Default::default()
        };

        let r_precond = cg_csr(&mat, &b, None, &config_precond).expect("CG failed");
        let r_nopc = cg_csr(&mat, &b, None, &config_nopc).expect("CG failed");

        assert!(r_precond.converged);
        assert!(r_nopc.converged);
        assert!(r_precond.n_iter <= r_nopc.n_iter + 5);
    }

    /// CG should accept a near-solution initial guess and still converge.
    #[test]
    fn test_cg_with_initial_guess() {
        let n = 5;
        let mat = tridiag_spd(n);
        let x_true = vec![1.0; n];
        let b = mat.spmv(&x_true).expect("spmv failed");

        let x0 = vec![0.9; n];
        let config = GpuSolverConfig::default();
        let result = cg_csr(&mat, &b, Some(&x0), &config).expect("CG failed");
        assert!(result.converged);
    }

    /// A right-hand side of the wrong length must be rejected by both solvers.
    #[test]
    fn test_solver_dimension_mismatch() {
        let n = 3;
        let mat = tridiag_spd(n);
        let b_wrong = vec![1.0; n + 1];
        let config = GpuSolverConfig::default();
        assert!(cg_csr(&mat, &b_wrong, None, &config).is_err());
        assert!(bicgstab_csr(&mat, &b_wrong, None, &config).is_err());
    }
}