1use crate::math_wasm::{Matrix, Vector};
2use std::fmt;
3
/// Stopping criteria handed to a solver when it is constructed.
#[derive(Debug, Clone)]
pub struct SolverConfig {
    /// Upper bound on the number of iterations a solver may perform.
    pub max_iterations: usize,
    /// Threshold a solver's residual measure must fall below to count as converged.
    pub tolerance: f64,
}

impl Default for SolverConfig {
    /// Returns a configuration allowing 1000 iterations with a tolerance of `1e-10`.
    fn default() -> Self {
        let max_iterations = 1000;
        let tolerance = 1e-10;

        SolverConfig {
            max_iterations,
            tolerance,
        }
    }
}
18
/// Error value returned by the solvers in this file when input is rejected
/// or an iteration fails to converge.
#[derive(Debug)]
pub struct SolverError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl fmt::Display for SolverError {
    /// Renders the error as `Solver error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Solver error: {}", self.message))
    }
}

// Marker impl: lets `SolverError` be used as `Box<dyn std::error::Error>` and with `?`.
impl std::error::Error for SolverError {}
31
32pub struct StepData {
33 pub iteration: usize,
34 pub residual: f64,
35 pub converged: bool,
36 pub solution: Vector,
37}
38
39pub struct ConjugateGradientSolver {
40 config: SolverConfig,
41 last_iteration_count: usize,
42}
43
44impl ConjugateGradientSolver {
45 pub fn new(config: SolverConfig) -> Self {
46 Self {
47 config,
48 last_iteration_count: 0,
49 }
50 }
51
52 pub fn solve(&mut self, a: &Matrix, b: &Vector) -> Result<Vector, SolverError> {
53 self.validate_input(a, b)?;
54
55 let n = b.len();
56 let mut x = Vector::zeros(n);
57 let mut r = b.subtract(&a.multiply_vector(&x).unwrap());
58 let mut p = r.clone();
59 let mut rsold = r.dot(&r);
60
61 for iteration in 0..self.config.max_iterations {
62 let ap = a.multiply_vector(&p).unwrap();
63 let alpha = rsold / p.dot(&ap);
64
65 x.axpy(alpha, &p);
66 r.axpy(-alpha, &ap);
67
68 let rsnew = r.dot(&r);
69 let residual = rsnew.sqrt();
70
71 self.last_iteration_count = iteration + 1;
72
73 if residual < self.config.tolerance {
74 return Ok(x);
75 }
76
77 let beta = rsnew / rsold;
78 p = r.add(&p.scale(beta));
79 rsold = rsnew;
80 }
81
82 Err(SolverError {
83 message: format!(
84 "Failed to converge after {} iterations. Final residual: {}",
85 self.config.max_iterations,
86 rsold.sqrt()
87 ),
88 })
89 }
90
91 pub fn solve_with_callback<F>(
92 &mut self,
93 a: &Matrix,
94 b: &Vector,
95 chunk_size: usize,
96 mut callback: F,
97 ) -> Result<Vector, SolverError>
98 where
99 F: FnMut(StepData),
100 {
101 self.validate_input(a, b)?;
102
103 let n = b.len();
104 let mut x = Vector::zeros(n);
105 let mut r = b.subtract(&a.multiply_vector(&x).unwrap());
106 let mut p = r.clone();
107 let mut rsold = r.dot(&r);
108
109 for iteration in 0..self.config.max_iterations {
110 let ap = a.multiply_vector(&p).unwrap();
111 let alpha = rsold / p.dot(&ap);
112
113 x.axpy(alpha, &p);
114 r.axpy(-alpha, &ap);
115
116 let rsnew = r.dot(&r);
117 let residual = rsnew.sqrt();
118
119 let converged = residual < self.config.tolerance;
120
121 if iteration % chunk_size == 0 || converged {
123 callback(StepData {
124 iteration: iteration + 1,
125 residual,
126 converged,
127 solution: x.clone(),
128 });
129 }
130
131 self.last_iteration_count = iteration + 1;
132
133 if converged {
134 return Ok(x);
135 }
136
137 let beta = rsnew / rsold;
138 p = r.add(&p.scale(beta));
139 rsold = rsnew;
140 }
141
142 Err(SolverError {
143 message: format!(
144 "Failed to converge after {} iterations. Final residual: {}",
145 self.config.max_iterations,
146 rsold.sqrt()
147 ),
148 })
149 }
150
151 pub fn get_last_iteration_count(&self) -> usize {
152 self.last_iteration_count
153 }
154
155 fn validate_input(&self, a: &Matrix, b: &Vector) -> Result<(), SolverError> {
156 if a.rows() != a.cols() {
157 return Err(SolverError {
158 message: "Matrix must be square".to_string(),
159 });
160 }
161
162 if a.rows() != b.len() {
163 return Err(SolverError {
164 message: "Matrix rows must match vector length".to_string(),
165 });
166 }
167
168 if !a.is_symmetric() {
169 return Err(SolverError {
170 message: "Matrix must be symmetric for conjugate gradient".to_string(),
171 });
172 }
173
174 if !a.is_positive_definite() {
175 return Err(SolverError {
176 message: "Matrix must be positive definite for conjugate gradient".to_string(),
177 });
178 }
179
180 Ok(())
181 }
182}
183
184pub struct JacobiSolver {
186 config: SolverConfig,
187 last_iteration_count: usize,
188}
189
190impl JacobiSolver {
191 pub fn new(config: SolverConfig) -> Self {
192 Self {
193 config,
194 last_iteration_count: 0,
195 }
196 }
197
198 pub fn solve(&mut self, a: &Matrix, b: &Vector) -> Result<Vector, SolverError> {
199 if a.rows() != a.cols() {
200 return Err(SolverError {
201 message: "Matrix must be square".to_string(),
202 });
203 }
204
205 if a.rows() != b.len() {
206 return Err(SolverError {
207 message: "Matrix rows must match vector length".to_string(),
208 });
209 }
210
211 let n = b.len();
212 let mut x = Vector::zeros(n);
213 let mut x_new = Vector::zeros(n);
214
215 for iteration in 0..self.config.max_iterations {
216 for i in 0..n {
217 let mut sum = 0.0;
218 for j in 0..n {
219 if i != j {
220 sum += a.get(i, j) * x.get(j);
221 }
222 }
223 x_new.set(i, (b.get(i) - sum) / a.get(i, i));
224 }
225
226 let diff = x_new.subtract(&x);
228 let residual = diff.norm();
229
230 self.last_iteration_count = iteration + 1;
231
232 if residual < self.config.tolerance {
233 return Ok(x_new);
234 }
235
236 x = x_new.clone();
237 }
238
239 Err(SolverError {
240 message: format!(
241 "Jacobi method failed to converge after {} iterations",
242 self.config.max_iterations
243 ),
244 })
245 }
246
247 pub fn get_last_iteration_count(&self) -> usize {
248 self.last_iteration_count
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a solver configuration from its two settings.
    fn config(max_iterations: usize, tolerance: f64) -> SolverConfig {
        SolverConfig {
            max_iterations,
            tolerance,
        }
    }

    /// Norm of `a * x - b`, i.e. how far `x` is from solving the system.
    fn residual_norm(a: &Matrix, x: &Vector, b: &Vector) -> f64 {
        let product = a.multiply_vector(x).unwrap();
        product.subtract(b).norm()
    }

    /// Conjugate gradient solves a small symmetric 2x2 system to tolerance.
    #[test]
    fn test_conjugate_gradient_simple() {
        let mut solver = ConjugateGradientSolver::new(config(100, 1e-10));

        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 3.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let solution = solver.solve(&a, &b).unwrap();
        let error = residual_norm(&a, &solution, &b);

        assert!(error < 1e-10, "Solution error too large: {}", error);
    }

    /// Jacobi solves a small diagonally dominant 2x2 system to tolerance.
    #[test]
    fn test_jacobi_simple() {
        let mut solver = JacobiSolver::new(config(1000, 1e-6));

        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 4.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let solution = solver.solve(&a, &b).unwrap();
        let error = residual_norm(&a, &solution, &b);

        assert!(error < 1e-6, "Solution error too large: {}", error);
    }

    /// The progress callback fires at least once when reporting every iteration.
    #[test]
    fn test_solver_with_callback() {
        let mut solver = ConjugateGradientSolver::new(config(100, 1e-10));
        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 3.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let mut callback_count = 0;
        let outcome = solver.solve_with_callback(&a, &b, 1, |_step| callback_count += 1);
        let _solution = outcome.unwrap();

        assert!(callback_count > 0, "Callback should have been called");
    }
}