1use scirs2_core::ndarray::{Array1, Array2};
22
23use crate::auto_precision::{solve_f32, solve_f64, Precision, PrecisionPolicy};
24use crate::error::LinalgError;
25
/// Diagnostics returned alongside the solution by [`MixedSolver::solve`].
#[derive(Debug, Clone)]
pub struct SolverStats {
    /// Precision chosen by `select_precision` for the initial solve.
    pub precision_used: Precision,
    /// Number of refinement passes whose correction was applied to the returned solution.
    pub refinement_steps_done: usize,
    /// Euclidean norm of the residual `b - A x` for the returned `x`.
    pub final_residual: f64,
}
40
/// Dense square linear solver: an initial solve at a precision picked by
/// `select_precision`, followed by optional iterative refinement in which the
/// residual and the correction are both computed in `f64`.
pub struct MixedSolver {
    /// Upper bound on the number of refinement passes after the initial solve.
    refinement_steps: usize,
    /// Refinement stops once the Euclidean residual norm falls below this value.
    tol: f64,
    /// Policy forwarded to `select_precision` to choose the initial-solve precision.
    policy: PrecisionPolicy,
}
57
58impl MixedSolver {
59 pub fn new(refinement_steps: usize, tol: f64) -> Self {
66 Self {
67 refinement_steps,
68 tol,
69 policy: PrecisionPolicy::default(),
70 }
71 }
72
73 pub fn with_policy(refinement_steps: usize, tol: f64, policy: PrecisionPolicy) -> Self {
75 Self {
76 refinement_steps,
77 tol,
78 policy,
79 }
80 }
81
82 pub fn solve(
91 &self,
92 a: &Array2<f64>,
93 b: &Array1<f64>,
94 ) -> Result<(Array1<f64>, SolverStats), LinalgError> {
95 let n = a.nrows();
96 if a.ncols() != n {
97 return Err(LinalgError::DimensionError(format!(
98 "MixedSolver requires a square matrix, got {}x{}",
99 n,
100 a.ncols()
101 )));
102 }
103 if b.len() != n {
104 return Err(LinalgError::DimensionError(format!(
105 "rhs length {} does not match matrix dimension {}",
106 b.len(),
107 n
108 )));
109 }
110
111 let precision = crate::auto_precision::select_precision(a, &self.policy);
113 let mut x = match precision {
114 Precision::Single => solve_f32(a, b)?,
115 Precision::Double | Precision::Mixed => solve_f64(a, b)?,
116 };
117
118 let mut steps_done = 0;
120 let mut final_res = residual_norm(a, b, &x);
121
122 for _ in 0..self.refinement_steps {
123 if final_res < self.tol {
124 break;
125 }
126 let r = compute_residual(a, b, &x);
128 let delta = solve_f64(a, &r)?;
130 for i in 0..n {
132 x[i] += delta[i];
133 }
134 steps_done += 1;
135 final_res = residual_norm(a, b, &x);
136 }
137
138 Ok((
139 x,
140 SolverStats {
141 precision_used: precision,
142 refinement_steps_done: steps_done,
143 final_residual: final_res,
144 },
145 ))
146 }
147}
148
149fn compute_residual(a: &Array2<f64>, b: &Array1<f64>, x: &Array1<f64>) -> Array1<f64> {
155 let n = a.nrows();
156 let mut r = b.to_owned();
157 for i in 0..n {
158 let mut ax_i = 0.0;
159 for j in 0..n {
160 ax_i += a[[i, j]] * x[j];
161 }
162 r[i] -= ax_i;
163 }
164 r
165}
166
167fn residual_norm(a: &Array2<f64>, b: &Array1<f64>, x: &Array1<f64>) -> f64 {
169 let r = compute_residual(a, b, x);
170 r.iter().map(|&ri| ri * ri).sum::<f64>().sqrt()
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto_precision::Precision;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mixed_solver_well_conditioned() {
        // Exact solution is (2, 3, -1).
        let a = array![[2.0_f64, 1.0, -1.0], [-3.0, -1.0, 2.0], [-2.0, 1.0, 2.0]];
        let b = array![8.0_f64, -11.0, -3.0];
        let solver = MixedSolver::new(3, 1e-12);
        let (x, stats) = solver.solve(&a, &b).expect("should succeed");

        assert!((x[0] - 2.0).abs() < 1e-8, "x[0]={}", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-8, "x[1]={}", x[1]);
        assert!((x[2] - (-1.0)).abs() < 1e-8, "x[2]={}", x[2]);
        assert!(
            stats.final_residual < 1e-10,
            "residual={}",
            stats.final_residual
        );
    }

    #[test]
    fn test_mixed_solver_force_single_refines() {
        let policy = PrecisionPolicy {
            force: Some(Precision::Single),
            ..Default::default()
        };
        // det = 11, so the exact solution is (1/11, 7/11).
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let b = array![1.0_f64, 2.0];
        let solver = MixedSolver::with_policy(5, 1e-12, policy);
        let (x, stats) = solver.solve(&a, &b).expect("should succeed");

        assert!((x[0] - 1.0 / 11.0).abs() < 1e-6, "x[0]={}", x[0]);
        assert!((x[1] - 7.0 / 11.0).abs() < 1e-6, "x[1]={}", x[1]);
        assert_eq!(stats.precision_used, Precision::Single);
    }

    #[test]
    fn test_mixed_solver_dimension_mismatch() {
        let a = Array2::<f64>::eye(3);
        let b = Array1::<f64>::zeros(2);
        let solver = MixedSolver::new(3, 1e-10);
        assert!(solver.solve(&a, &b).is_err());
    }

    #[test]
    fn test_mixed_solver_non_square() {
        let a = Array2::<f64>::zeros((2, 3));
        let b = Array1::<f64>::zeros(2);
        let solver = MixedSolver::new(3, 1e-10);
        assert!(solver.solve(&a, &b).is_err());
    }

    #[test]
    fn test_mixed_solver_stats_precision() {
        // NOTE(review): relies on the default policy selecting Single for this
        // well-conditioned matrix.
        let a = array![[2.0_f64, 0.5], [0.5, 2.0]];
        let b = array![1.0_f64, 1.0];
        let solver = MixedSolver::new(3, 1e-10);
        let (_x, stats) = solver.solve(&a, &b).expect("should succeed");
        assert_eq!(stats.precision_used, Precision::Single);
    }

    #[test]
    fn test_mixed_solver_zero_refinement_steps() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let b = array![1.0_f64, 2.0];
        let solver = MixedSolver::new(0, 1e-12);
        let (x, stats) = solver.solve(&a, &b).expect("should succeed");

        assert_eq!(stats.refinement_steps_done, 0);
        // With refinement disabled the reported residual is the initial solve's.
        assert_eq!(stats.final_residual, residual_norm(&a, &b, &x));
    }

    #[test]
    fn test_mixed_solver_stats_consistent_with_solution() {
        let policy = PrecisionPolicy {
            force: Some(Precision::Single),
            ..Default::default()
        };
        let a = array![[2.0_f64, 1.0, -1.0], [-3.0, -1.0, 2.0], [-2.0, 1.0, 2.0]];
        let b = array![8.0_f64, -11.0, -3.0];
        let max_steps = 4;
        let solver = MixedSolver::with_policy(max_steps, 1e-12, policy);
        let (x, stats) = solver.solve(&a, &b).expect("should succeed");

        assert!(stats.refinement_steps_done <= max_steps);
        // The reported residual must describe the solution actually returned.
        assert_eq!(stats.final_residual, residual_norm(&a, &b, &x));
    }
}