sublinear_solver/solver_core.rs

use crate::math_wasm::{Matrix, Vector};
use std::fmt;

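/// Shared configuration for the iterative solvers in this module: the
/// iteration budget and the residual tolerance used as the stopping test.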
#[derive(Debug, Clone)]
pub struct SolverConfig {
    pub max_iterations: usize,
    pub tolerance: f64,
}

impl Default for SolverConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-10,
        }
    }
}

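/// Error returned when input validation fails or a solver does not converge
/// within `max_iterations`.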
#[derive(Debug)]
pub struct SolverError {
    pub message: String,
}

impl fmt::Display for SolverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Solver error: {}", self.message)
    }
}

impl std::error::Error for SolverError {}

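/// Progress snapshot handed to the callback in
/// [`ConjugateGradientSolver::solve_with_callback`].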
pub struct StepData {
    pub iteration: usize,
    pub residual: f64,
    pub converged: bool,
    pub solution: Vector,
}

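/// Conjugate gradient solver for symmetric positive definite systems.
///
/// Example (a minimal sketch mirroring the tests below; `Matrix` and `Vector`
/// come from `crate::math_wasm`):
///
/// ```ignore
/// let mut solver = ConjugateGradientSolver::new(SolverConfig::default());
/// let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 3.0], 2, 2);
/// let b = Vector::from_slice(&[1.0, 2.0]);
/// let x = solver.solve(&a, &b)?;
/// ```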
pub struct ConjugateGradientSolver {
    config: SolverConfig,
    last_iteration_count: usize,
}

impl ConjugateGradientSolver {
    pub fn new(config: SolverConfig) -> Self {
        Self {
            config,
            last_iteration_count: 0,
        }
    }

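    /// Solves `A x = b` by conjugate gradients, starting from `x = 0`.
    ///
    /// Each iteration takes a step `alpha = (r·r) / (p·Ap)` along the search
    /// direction `p`, updates `x += alpha p` and `r -= alpha Ap`, and then
    /// forms the next direction `p = r + beta p` with
    /// `beta = (r_new·r_new) / (r·r)`. Iteration stops once `‖r‖` falls
    /// below `tolerance`.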
    pub fn solve(&mut self, a: &Matrix, b: &Vector) -> Result<Vector, SolverError> {
        self.validate_input(a, b)?;

        let n = b.len();
        let mut x = Vector::zeros(n);
        let mut r = b.subtract(&a.multiply_vector(&x).unwrap());
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for iteration in 0..self.config.max_iterations {
            let ap = a.multiply_vector(&p).unwrap();
            // Step length that minimizes the quadratic form along direction p.
            let alpha = rsold / p.dot(&ap);

            x.axpy(alpha, &p);
            r.axpy(-alpha, &ap);

            let rsnew = r.dot(&r);
            let residual = rsnew.sqrt();

            self.last_iteration_count = iteration + 1;

            if residual < self.config.tolerance {
                return Ok(x);
            }

            // Next search direction, conjugate to the previous ones.
            let beta = rsnew / rsold;
            p = r.add(&p.scale(beta));
            rsold = rsnew;
        }

        Err(SolverError {
            message: format!(
                "Failed to converge after {} iterations. Final residual: {}",
                self.config.max_iterations,
                rsold.sqrt()
            ),
        })
    }

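    /// Like [`solve`](Self::solve), but invokes `callback` with a [`StepData`]
    /// snapshot every `chunk_size` iterations and once on convergence, which
    /// is useful for streaming progress to a caller incrementally.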
    pub fn solve_with_callback<F>(
        &mut self,
        a: &Matrix,
        b: &Vector,
        chunk_size: usize,
        mut callback: F,
    ) -> Result<Vector, SolverError>
    where
        F: FnMut(StepData),
    {
        self.validate_input(a, b)?;

        // Guard against a zero chunk size, which would otherwise panic in the
        // modulo operation below.
        let chunk_size = chunk_size.max(1);

        let n = b.len();
        let mut x = Vector::zeros(n);
        let mut r = b.subtract(&a.multiply_vector(&x).unwrap());
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for iteration in 0..self.config.max_iterations {
            let ap = a.multiply_vector(&p).unwrap();
            let alpha = rsold / p.dot(&ap);

            x.axpy(alpha, &p);
            r.axpy(-alpha, &ap);

            let rsnew = r.dot(&r);
            let residual = rsnew.sqrt();

            let converged = residual < self.config.tolerance;

            // Call callback every chunk_size iterations or on convergence
            if iteration % chunk_size == 0 || converged {
                callback(StepData {
                    iteration: iteration + 1,
                    residual,
                    converged,
                    solution: x.clone(),
                });
            }

            self.last_iteration_count = iteration + 1;

            if converged {
                return Ok(x);
            }

            let beta = rsnew / rsold;
            p = r.add(&p.scale(beta));
            rsold = rsnew;
        }

        Err(SolverError {
            message: format!(
                "Failed to converge after {} iterations. Final residual: {}",
                self.config.max_iterations,
                rsold.sqrt()
            ),
        })
    }

    pub fn get_last_iteration_count(&self) -> usize {
        self.last_iteration_count
    }

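    /// Checks that `a` is square, matches the length of `b`, and is symmetric
    /// positive definite, the conditions under which conjugate gradient is
    /// guaranteed to converge in exact arithmetic.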
    fn validate_input(&self, a: &Matrix, b: &Vector) -> Result<(), SolverError> {
        if a.rows() != a.cols() {
            return Err(SolverError {
                message: "Matrix must be square".to_string(),
            });
        }

        if a.rows() != b.len() {
            return Err(SolverError {
                message: "Matrix rows must match vector length".to_string(),
            });
        }

        if !a.is_symmetric() {
            return Err(SolverError {
                message: "Matrix must be symmetric for conjugate gradient".to_string(),
            });
        }

        if !a.is_positive_definite() {
            return Err(SolverError {
                message: "Matrix must be positive definite for conjugate gradient".to_string(),
            });
        }

        Ok(())
    }
}

/// Jacobi solver, kept as a simpler alternative for comparison and benchmarking.
pub struct JacobiSolver {
    config: SolverConfig,
    last_iteration_count: usize,
}

impl JacobiSolver {
    pub fn new(config: SolverConfig) -> Self {
        Self {
            config,
            last_iteration_count: 0,
        }
    }

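    /// Solves `A x = b` with Jacobi iteration: each sweep sets
    /// `x_i = (b_i - Σ_{j≠i} a_ij x_j) / a_ii` using the previous iterate on
    /// the right-hand side. Convergence is guaranteed for strictly diagonally
    /// dominant matrices.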
    pub fn solve(&mut self, a: &Matrix, b: &Vector) -> Result<Vector, SolverError> {
        if a.rows() != a.cols() {
            return Err(SolverError {
                message: "Matrix must be square".to_string(),
            });
        }

        if a.rows() != b.len() {
            return Err(SolverError {
                message: "Matrix rows must match vector length".to_string(),
            });
        }

        let n = b.len();

        // Jacobi divides by the diagonal entries, so reject a zero diagonal
        // up front rather than silently producing NaN or infinity.
        for i in 0..n {
            if a.get(i, i) == 0.0 {
                return Err(SolverError {
                    message: "Matrix must have nonzero diagonal entries for Jacobi".to_string(),
                });
            }
        }

        let mut x = Vector::zeros(n);
        let mut x_new = Vector::zeros(n);

        for iteration in 0..self.config.max_iterations {
            for i in 0..n {
                let mut sum = 0.0;
                for j in 0..n {
                    if i != j {
                        sum += a.get(i, j) * x.get(j);
                    }
                }
                x_new.set(i, (b.get(i) - sum) / a.get(i, i));
            }

            // Use the norm of the update as the convergence test (this is the
            // change between iterates, not the true residual ‖Ax − b‖).
            let change = x_new.subtract(&x).norm();

            self.last_iteration_count = iteration + 1;

            if change < self.config.tolerance {
                return Ok(x_new);
            }

            x = x_new.clone();
        }

        Err(SolverError {
            message: format!(
                "Jacobi method failed to converge after {} iterations",
                self.config.max_iterations
            ),
        })
    }

    pub fn get_last_iteration_count(&self) -> usize {
        self.last_iteration_count
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conjugate_gradient_simple() {
        let config = SolverConfig {
            max_iterations: 100,
            tolerance: 1e-10,
        };

        let mut solver = ConjugateGradientSolver::new(config);

        // Simple 2x2 positive definite system
        // [4 1] [x1]   [1]
        // [1 3] [x2] = [2]
        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 3.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let solution = solver.solve(&a, &b).unwrap();

        // Verify solution by substituting back
        let result = a.multiply_vector(&solution).unwrap();
        let error = result.subtract(&b).norm();

        assert!(error < 1e-10, "Solution error too large: {}", error);
    }

    #[test]
    fn test_jacobi_simple() {
        let config = SolverConfig {
            max_iterations: 1000,
            tolerance: 1e-6,
        };

        let mut solver = JacobiSolver::new(config);

        // Diagonally dominant system for Jacobi convergence
        // [4 1] [x1]   [1]
        // [1 4] [x2] = [2]
        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 4.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let solution = solver.solve(&a, &b).unwrap();

        // Verify solution
        let result = a.multiply_vector(&solution).unwrap();
        let error = result.subtract(&b).norm();

        assert!(error < 1e-6, "Solution error too large: {}", error);
    }

    #[test]
    fn test_solver_with_callback() {
        let config = SolverConfig {
            max_iterations: 100,
            tolerance: 1e-10,
        };

        let mut solver = ConjugateGradientSolver::new(config);
        let a = Matrix::from_slice(&[4.0, 1.0, 1.0, 3.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        let mut callback_count = 0;
        let _solution = solver.solve_with_callback(&a, &b, 1, |_step| {
            callback_count += 1;
        }).unwrap();

        assert!(callback_count > 0, "Callback should have been called");
    }
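
    // A further check (a minimal sketch using only the constructors exercised
    // above): validate_input should reject a matrix that is not symmetric.
    #[test]
    fn test_rejects_non_symmetric_matrix() {
        let mut solver = ConjugateGradientSolver::new(SolverConfig::default());

        // Row-major data as in the tests above; this matrix is not symmetric.
        let a = Matrix::from_slice(&[4.0, 1.0, 2.0, 3.0], 2, 2);
        let b = Vector::from_slice(&[1.0, 2.0]);

        assert!(solver.solve(&a, &b).is_err());
    }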
}