Skip to main content

scirs2_linalg/
mixed_cpu_gpu_solver.rs

1//! Mixed CPU/GPU linear system solver.
2//!
3//! Performs matrix factorization at low precision (simulating GPU dispatch)
4//! then applies iterative residual refinement on the CPU at higher precision.
5//!
6//! # Algorithm
7//!
//! 1. Determine the precision using the auto-precision policy.
//! 2. Solve the system `Ax = b` at the selected precision (f32 or f64) to
//!    obtain an initial solution `x`.
//! 3. Compute the residual `r = b - A x` in f64.
//! 4. Solve the correction system `A delta = r` in f64 and apply `x += delta`.
//! 5. Repeat steps 3 and 4 until the residual 2-norm is smaller than `tol` or
//!    `refinement_steps` iterations have been performed.
14//!
15//! # References
16//!
17//! - Higham (2002). "Accuracy and Stability of Numerical Algorithms." §12.4.
18//! - Demmel et al. (2006). "Error bounds from extra-precise iterative
19//!   refinement."
20
21use scirs2_core::ndarray::{Array1, Array2};
22
23use crate::auto_precision::{solve_f32, solve_f64, Precision, PrecisionPolicy};
24use crate::error::LinalgError;
25
26// ---------------------------------------------------------------------------
27// Public types
28// ---------------------------------------------------------------------------
29
30/// Statistics returned by [`MixedSolver::solve`].
31#[derive(Debug, Clone)]
32pub struct SolverStats {
33    /// Which precision was used for the initial solve.
34    pub precision_used: Precision,
35    /// Number of iterative refinement steps actually applied.
36    pub refinement_steps_done: usize,
37    /// 2-norm of the final residual `||b - A x||`.
38    pub final_residual: f64,
39}
40
41// ---------------------------------------------------------------------------
42// Solver
43// ---------------------------------------------------------------------------
44
45/// Mixed CPU/GPU linear system solver with iterative residual refinement.
46///
47/// The factorization step is dispatched at the precision recommended by
48/// [`PrecisionPolicy`].  Refinement steps always run in f64 on the CPU.
49pub struct MixedSolver {
50    /// Maximum number of iterative refinement steps.
51    refinement_steps: usize,
52    /// Convergence tolerance for the residual 2-norm.
53    tol: f64,
54    /// Precision policy for the initial factorization.
55    policy: PrecisionPolicy,
56}
57
58impl MixedSolver {
59    /// Create a new solver.
60    ///
61    /// # Arguments
62    ///
63    /// * `refinement_steps` — Maximum number of residual refinement iterations.
64    /// * `tol` — Stop refining when `||b - Ax|| < tol`.
65    pub fn new(refinement_steps: usize, tol: f64) -> Self {
66        Self {
67            refinement_steps,
68            tol,
69            policy: PrecisionPolicy::default(),
70        }
71    }
72
73    /// Create a new solver with an explicit precision policy.
74    pub fn with_policy(refinement_steps: usize, tol: f64, policy: PrecisionPolicy) -> Self {
75        Self {
76            refinement_steps,
77            tol,
78            policy,
79        }
80    }
81
82    /// Solve `Ax = b`.
83    ///
84    /// Returns the solution and solver statistics.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if the matrix is singular, dimensions are mismatched,
89    /// or an internal numerical failure occurs.
90    pub fn solve(
91        &self,
92        a: &Array2<f64>,
93        b: &Array1<f64>,
94    ) -> Result<(Array1<f64>, SolverStats), LinalgError> {
95        let n = a.nrows();
96        if a.ncols() != n {
97            return Err(LinalgError::DimensionError(format!(
98                "MixedSolver requires a square matrix, got {}x{}",
99                n,
100                a.ncols()
101            )));
102        }
103        if b.len() != n {
104            return Err(LinalgError::DimensionError(format!(
105                "rhs length {} does not match matrix dimension {}",
106                b.len(),
107                n
108            )));
109        }
110
111        // Step 1: select precision and solve initial system
112        let precision = crate::auto_precision::select_precision(a, &self.policy);
113        let mut x = match precision {
114            Precision::Single => solve_f32(a, b)?,
115            Precision::Double | Precision::Mixed => solve_f64(a, b)?,
116        };
117
118        // Step 2: iterative residual refinement
119        let mut steps_done = 0;
120        let mut final_res = residual_norm(a, b, &x);
121
122        for _ in 0..self.refinement_steps {
123            if final_res < self.tol {
124                break;
125            }
126            // Compute residual r = b - Ax  (in f64)
127            let r = compute_residual(a, b, &x);
128            // Solve A delta = r  (in f64)
129            let delta = solve_f64(a, &r)?;
130            // Apply correction
131            for i in 0..n {
132                x[i] += delta[i];
133            }
134            steps_done += 1;
135            final_res = residual_norm(a, b, &x);
136        }
137
138        Ok((
139            x,
140            SolverStats {
141                precision_used: precision,
142                refinement_steps_done: steps_done,
143                final_residual: final_res,
144            },
145        ))
146    }
147}
148
149// ---------------------------------------------------------------------------
150// Helpers
151// ---------------------------------------------------------------------------
152
153/// Compute the residual vector `r = b - A x`.
154fn compute_residual(a: &Array2<f64>, b: &Array1<f64>, x: &Array1<f64>) -> Array1<f64> {
155    let n = a.nrows();
156    let mut r = b.to_owned();
157    for i in 0..n {
158        let mut ax_i = 0.0;
159        for j in 0..n {
160            ax_i += a[[i, j]] * x[j];
161        }
162        r[i] -= ax_i;
163    }
164    r
165}
166
167/// Compute the 2-norm of the residual `||b - Ax||`.
168fn residual_norm(a: &Array2<f64>, b: &Array1<f64>, x: &Array1<f64>) -> f64 {
169    let r = compute_residual(a, b, x);
170    r.iter().map(|&ri| ri * ri).sum::<f64>().sqrt()
171}
172
173// ---------------------------------------------------------------------------
174// Tests
175// ---------------------------------------------------------------------------
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto_precision::Precision;
    use scirs2_core::ndarray::array;

    /// Assert that the leading components of `actual` lie within `tol` of
    /// the corresponding entries of `expected`.
    fn assert_components_close(actual: &Array1<f64>, expected: &[f64], tol: f64) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual[i];
            assert!((got - want).abs() < tol, "x[{}]={}", i, got);
        }
    }

    #[test]
    fn test_mixed_solver_well_conditioned() {
        // 3x3 system whose exact solution is (2, 3, -1).
        let matrix = array![[2.0_f64, 1.0, -1.0], [-3.0, -1.0, 2.0], [-2.0, 1.0, 2.0]];
        let rhs = array![8.0_f64, -11.0, -3.0];

        let (solution, stats) = MixedSolver::new(3, 1e-12)
            .solve(&matrix, &rhs)
            .expect("should succeed");

        assert_components_close(&solution, &[2.0, 3.0, -1.0], 1e-8);
        assert!(
            stats.final_residual < 1e-10,
            "residual={}",
            stats.final_residual
        );
    }

    #[test]
    fn test_mixed_solver_force_single_refines() {
        // Pin the initial solve to single precision via the policy.
        let single_only = PrecisionPolicy {
            force: Some(Precision::Single),
            ..Default::default()
        };
        let matrix = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let rhs = array![1.0_f64, 2.0];

        let (solution, stats) = MixedSolver::with_policy(5, 1e-12, single_only)
            .solve(&matrix, &rhs)
            .expect("should succeed");

        // Exact solution: x = [1/11, 7/11]
        assert_components_close(&solution, &[1.0 / 11.0, 7.0 / 11.0], 1e-6);
        assert_eq!(stats.precision_used, Precision::Single);
    }

    #[test]
    fn test_mixed_solver_dimension_mismatch() {
        // 3x3 matrix paired with a right-hand side of length 2.
        let matrix = Array2::<f64>::eye(3);
        let rhs = Array1::<f64>::zeros(2);
        let outcome = MixedSolver::new(3, 1e-10).solve(&matrix, &rhs);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mixed_solver_non_square() {
        // A 2x3 matrix must be rejected.
        let matrix = Array2::<f64>::zeros((2, 3));
        let rhs = Array1::<f64>::zeros(2);
        let outcome = MixedSolver::new(3, 1e-10).solve(&matrix, &rhs);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mixed_solver_stats_precision() {
        // A well-conditioned matrix should use Single
        // (the default policy threshold is 1e4).
        let matrix = array![[2.0_f64, 0.5], [0.5, 2.0]];
        let rhs = array![1.0_f64, 1.0];
        let (_solution, stats) = MixedSolver::new(3, 1e-10)
            .solve(&matrix, &rhs)
            .expect("should succeed");
        assert_eq!(stats.precision_used, Precision::Single);
    }
}