1use crate::core::{Matrix, Vector, SparseMatrix, Complexity};
4use crate::FTLError;
5use std::time::{Duration, Instant};
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
/// Selects which algorithm [`SublinearSolver::solve`] dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SolverMethod {
    /// Truncated Neumann-series / Jacobi-style iteration (`solve_neumann`).
    Neumann,
    /// Monte Carlo random-walk estimation (`solve_random_walk`).
    RandomWalk,
    /// Largest-residual "push" relaxation (`solve_forward_push`).
    ForwardPush,
    /// Backward push (`solve_backward_push`).
    BackwardPush,
    /// Average of the forward and backward push solutions (`solve_bidirectional`).
    Bidirectional,
    /// Defer the choice to solve time, based on sparsity and diagonal
    /// dominance (see `select_best_method`).
    Adaptive,
}
25
/// Tunable parameters for [`SublinearSolver`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverConfig {
    /// Algorithm to use; `Adaptive` defers the choice until solve time.
    pub method: SolverMethod,
    /// Convergence tolerance, used both as the iterate-difference threshold
    /// (Neumann) and to derive the residual threshold (push methods).
    pub epsilon: f64,
    /// Hard cap on iteration counts (applied to the Neumann loop).
    pub max_iterations: usize,
    /// NOTE(review): not consulted by any solver path in this file yet.
    pub parallel: bool,
    /// NOTE(review): not consulted by any solver path in this file yet.
    pub timeout: Duration,
}
35
36impl Default for SolverConfig {
37 fn default() -> Self {
38 Self {
39 method: SolverMethod::Adaptive,
40 epsilon: 1e-6,
41 max_iterations: 100,
42 parallel: true,
43 timeout: Duration::from_millis(100),
44 }
45 }
46}
47
/// Linear-system solver offering several (intended-)sublinear strategies.
pub struct SublinearSolver {
    /// Behavior knobs; see [`SolverConfig`].
    config: SolverConfig,
}
52
53impl SublinearSolver {
54 pub fn new() -> Self {
56 Self {
57 config: SolverConfig::default(),
58 }
59 }
60
61 pub fn with_method(method: SolverMethod) -> Self {
63 let mut config = SolverConfig::default();
64 config.method = method;
65 Self { config }
66 }
67
68 pub fn with_config(config: SolverConfig) -> Self {
70 Self { config }
71 }
72
73 pub fn solve(&self, a: &Matrix, b: &Vector) -> crate::Result<SolverResult> {
75 let start = Instant::now();
76
77 self.validate_inputs(a, b)?;
79
80 let method = if self.config.method == SolverMethod::Adaptive {
82 self.select_best_method(a)
83 } else {
84 self.config.method
85 };
86
87 let solution = match method {
89 SolverMethod::Neumann => self.solve_neumann(a, b)?,
90 SolverMethod::RandomWalk => self.solve_random_walk(a, b)?,
91 SolverMethod::ForwardPush => self.solve_forward_push(a, b)?,
92 SolverMethod::BackwardPush => self.solve_backward_push(a, b)?,
93 SolverMethod::Bidirectional => self.solve_bidirectional(a, b)?,
94 SolverMethod::Adaptive => unreachable!(),
95 };
96
97 let elapsed = start.elapsed();
98
99 let complexity = self.estimate_complexity(a.shape().0, elapsed);
101
102 Ok(SolverResult {
103 solution,
104 method,
105 iterations: self.config.max_iterations,
106 residual: self.compute_residual(a, &solution, b),
107 time: elapsed,
108 complexity,
109 })
110 }
111
112 fn validate_inputs(&self, a: &Matrix, b: &Vector) -> crate::Result<()> {
114 let (rows, cols) = a.shape();
115
116 if rows != cols {
117 return Err(FTLError::MatrixError("Matrix must be square".to_string()));
118 }
119
120 if b.len() != rows {
121 return Err(FTLError::MatrixError(
122 "Vector dimension mismatch".to_string(),
123 ));
124 }
125
126 if !self.is_diagonally_dominant(a) {
128 log::warn!("Matrix is not diagonally dominant - convergence not guaranteed");
130 }
131
132 Ok(())
133 }
134
135 fn is_diagonally_dominant(&self, a: &Matrix) -> bool {
137 let (n, _) = a.shape();
138 let view = a.view();
139
140 for i in 0..n {
141 let diagonal = view[[i, i]].abs();
142 let mut off_diagonal_sum = 0.0;
143
144 for j in 0..n {
145 if i != j {
146 off_diagonal_sum += view[[i, j]].abs();
147 }
148 }
149
150 if diagonal <= off_diagonal_sum {
151 return false;
152 }
153 }
154
155 true
156 }
157
158 fn select_best_method(&self, a: &Matrix) -> SolverMethod {
160 let sparse = a.to_sparse();
161 let sparsity = sparse.sparsity();
162
163 if sparsity > 0.95 {
164 SolverMethod::ForwardPush
166 } else if self.is_diagonally_dominant(a) {
167 SolverMethod::Neumann
169 } else {
170 SolverMethod::Bidirectional
172 }
173 }
174
175 fn solve_neumann(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
177 let n = b.len();
178 let mut x = b.clone();
179 let identity_minus_a = self.compute_iteration_matrix(a)?;
180
181 let iterations = (n as f64).log2().ceil() as usize;
184 let actual_iterations = iterations.min(self.config.max_iterations);
185
186 for _ in 0..actual_iterations {
187 let mx = identity_minus_a.multiply_vector(&x);
188 let new_x = b.add(&mx);
189
190 let diff = new_x.sub(&x).norm();
192 if diff < self.config.epsilon {
193 return Ok(new_x);
194 }
195
196 x = new_x;
197 }
198
199 Ok(x)
200 }
201
202 fn solve_random_walk(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
204 use rand::Rng;
205 let mut rng = rand::thread_rng();
206 let n = b.len();
207 let mut solution = Vector::zeros(n);
208
209 let num_walks = ((n as f64).log2() * 100.0) as usize;
211 let walk_length = (n as f64).log2().ceil() as usize;
212
213 for i in 0..n {
214 let mut estimate = 0.0;
215
216 for _ in 0..num_walks {
217 let mut current = i;
219 let mut weight = 1.0;
220
221 for _ in 0..walk_length {
222 let next = rng.gen_range(0..n);
224 weight *= a.view()[[current, next]];
225 current = next;
226
227 if weight.abs() < 1e-10 {
228 break;
229 }
230 }
231
232 estimate += weight * b.view()[current];
233 }
234
235 solution.data[i] = estimate / num_walks as f64;
236 }
237
238 Ok(solution)
239 }
240
241 fn solve_forward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
243 let n = b.len();
244 let mut solution = b.clone();
245 let mut residual = b.clone();
246
247 let threshold = self.config.epsilon / (n as f64).sqrt();
249 let max_pushes = (n as f64).log2().ceil() as usize * 10;
250
251 for _ in 0..max_pushes {
252 let mut max_residual = 0.0;
254 let mut max_idx = 0;
255
256 for i in 0..n {
257 if residual.data[i].abs() > max_residual {
258 max_residual = residual.data[i].abs();
259 max_idx = i;
260 }
261 }
262
263 if max_residual < threshold {
264 break;
265 }
266
267 let push_value = residual.data[max_idx];
269 solution.data[max_idx] += push_value;
270
271 for j in 0..n {
273 residual.data[j] -= push_value * a.view()[[max_idx, j]];
274 }
275 residual.data[max_idx] = 0.0;
276 }
277
278 Ok(solution)
279 }
280
281 fn solve_backward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
283 self.solve_forward_push(a, b) }
286
287 fn solve_bidirectional(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
289 let forward = self.solve_forward_push(a, b)?;
290 let backward = self.solve_backward_push(a, b)?;
291
292 Ok(forward.add(&backward).scale(0.5))
294 }
295
296 fn compute_iteration_matrix(&self, a: &Matrix) -> crate::Result<Matrix> {
298 let (n, _) = a.shape();
299 let mut m = Matrix::random(n, n);
300
301 for i in 0..n {
303 for j in 0..n {
304 if i == j {
305 m.data[[i, j]] = 0.0;
306 } else {
307 let diagonal = a.view()[[i, i]];
308 if diagonal.abs() > 1e-10 {
309 m.data[[i, j]] = -a.view()[[i, j]] / diagonal;
310 }
311 }
312 }
313 }
314
315 Ok(m)
316 }
317
318 fn compute_residual(&self, a: &Matrix, x: &Vector, b: &Vector) -> f64 {
320 let ax = a.multiply_vector(x);
321 ax.sub(b).norm()
322 }
323
324 fn estimate_complexity(&self, n: usize, elapsed: Duration) -> Complexity {
326 let nanos = elapsed.as_nanos() as f64;
327 let log_n = (n as f64).log2();
328
329 let ratios = vec![
331 (Complexity::Constant, 1.0),
332 (Complexity::Logarithmic, log_n),
333 (Complexity::Linear, n as f64),
334 (Complexity::Quadratic, (n * n) as f64),
335 (Complexity::Cubic, (n * n * n) as f64),
336 ];
337
338 let mut best_complexity = Complexity::Cubic;
340 let mut min_diff = f64::MAX;
341
342 for (complexity, theoretical) in ratios {
343 let diff = (nanos / theoretical - 1.0).abs();
344 if diff < min_diff {
345 min_diff = diff;
346 best_complexity = complexity;
347 }
348 }
349
350 best_complexity
351 }
352}
353
/// Outcome of a [`SublinearSolver::solve`] call.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Approximate solution vector `x`.
    pub solution: Vector,
    /// Concrete method that actually ran (never `Adaptive`).
    pub method: SolverMethod,
    /// NOTE(review): `solve` fills this with `config.max_iterations`, an
    /// upper bound — not the number of iterations actually performed.
    pub iterations: usize,
    /// Final residual norm ‖A·x − b‖.
    pub residual: f64,
    /// Wall-clock solve duration.
    pub time: Duration,
    /// Heuristic complexity class estimated from `n` and the elapsed time.
    pub complexity: Complexity,
}
364
365impl SolverResult {
366 pub fn converged(&self, tolerance: f64) -> bool {
368 self.residual < tolerance
369 }
370
371 pub fn time_microseconds(&self) -> f64 {
373 self.time.as_secs_f64() * 1_000_000.0
374 }
375
376 pub fn is_sublinear(&self) -> bool {
378 matches!(
379 self.complexity,
380 Complexity::Constant | Complexity::Logarithmic
381 )
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    // Neumann iteration on a well-conditioned (diagonally dominant) system.
    #[test]
    fn test_neumann_solver() {
        let a = Matrix::diagonally_dominant(10, 2.0);
        let b = Vector::ones(10);
        let solver = SublinearSolver::with_method(SolverMethod::Neumann);

        let result = solver.solve(&a, &b).unwrap();
        // NOTE(review): this depends on the timing-based heuristic in
        // `estimate_complexity` and may be flaky on loaded machines.
        assert!(result.is_sublinear());
    }

    // Forward push on a larger dominant system.
    #[test]
    fn test_forward_push() {
        let a = Matrix::diagonally_dominant(100, 3.0);
        let b = Vector::random(100);
        let solver = SublinearSolver::with_method(SolverMethod::ForwardPush);

        let result = solver.solve(&a, &b).unwrap();
        // NOTE(review): a wall-clock bound (< 1 ms) is inherently flaky in
        // debug builds / busy CI — consider asserting on the residual instead.
        assert!(result.time_microseconds() < 1000.0); }

    // Adaptive mode should pick a method that converges on this system.
    #[test]
    fn test_adaptive_selection() {
        let sparse = Matrix::diagonally_dominant(50, 5.0);
        let b = Vector::ones(50);
        let solver = SublinearSolver::new(); let result = solver.solve(&sparse, &b).unwrap();
        assert!(result.converged(1e-3));
    }
}
418}