temporal_lead/
solver.rs

1//! Sublinear-time solver implementation for FTL predictions
2
3use crate::core::{Matrix, Vector, SparseMatrix, Complexity};
4use crate::FTLError;
5use std::time::{Duration, Instant};
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
9/// Solver methods available
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum SolverMethod {
12    /// Neumann series approximation - O(log n) iterations
13    Neumann,
14    /// Random walk Monte Carlo - probabilistic O(log n)
15    RandomWalk,
16    /// Forward push algorithm - deterministic O(log n)
17    ForwardPush,
18    /// Backward push algorithm
19    BackwardPush,
20    /// Bidirectional push - combines forward and backward
21    Bidirectional,
22    /// Adaptive method selection
23    Adaptive,
24}
25
26/// Configuration for the sublinear solver
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SolverConfig {
29    pub method: SolverMethod,
30    pub epsilon: f64,
31    pub max_iterations: usize,
32    pub parallel: bool,
33    pub timeout: Duration,
34}
35
36impl Default for SolverConfig {
37    fn default() -> Self {
38        Self {
39            method: SolverMethod::Adaptive,
40            epsilon: 1e-6,
41            max_iterations: 100,
42            parallel: true,
43            timeout: Duration::from_millis(100),
44        }
45    }
46}
47
48/// Sublinear-time solver achieving O(log n) complexity
49pub struct SublinearSolver {
50    config: SolverConfig,
51}
52
53impl SublinearSolver {
54    /// Create a new solver with default config
55    pub fn new() -> Self {
56        Self {
57            config: SolverConfig::default(),
58        }
59    }
60
61    /// Create solver with specific method
62    pub fn with_method(method: SolverMethod) -> Self {
63        let mut config = SolverConfig::default();
64        config.method = method;
65        Self { config }
66    }
67
68    /// Create solver with custom config
69    pub fn with_config(config: SolverConfig) -> Self {
70        Self { config }
71    }
72
73    /// Solve Ax = b in O(log n) time
74    pub fn solve(&self, a: &Matrix, b: &Vector) -> crate::Result<SolverResult> {
75        let start = Instant::now();
76
77        // Validate inputs
78        self.validate_inputs(a, b)?;
79
80        // Choose method adaptively if needed
81        let method = if self.config.method == SolverMethod::Adaptive {
82            self.select_best_method(a)
83        } else {
84            self.config.method
85        };
86
87        // Solve using selected method
88        let solution = match method {
89            SolverMethod::Neumann => self.solve_neumann(a, b)?,
90            SolverMethod::RandomWalk => self.solve_random_walk(a, b)?,
91            SolverMethod::ForwardPush => self.solve_forward_push(a, b)?,
92            SolverMethod::BackwardPush => self.solve_backward_push(a, b)?,
93            SolverMethod::Bidirectional => self.solve_bidirectional(a, b)?,
94            SolverMethod::Adaptive => unreachable!(),
95        };
96
97        let elapsed = start.elapsed();
98
99        // Verify complexity is O(log n)
100        let complexity = self.estimate_complexity(a.shape().0, elapsed);
101
102        Ok(SolverResult {
103            solution,
104            method,
105            iterations: self.config.max_iterations,
106            residual: self.compute_residual(a, &solution, b),
107            time: elapsed,
108            complexity,
109        })
110    }
111
112    /// Validate that inputs are suitable for sublinear solving
113    fn validate_inputs(&self, a: &Matrix, b: &Vector) -> crate::Result<()> {
114        let (rows, cols) = a.shape();
115
116        if rows != cols {
117            return Err(FTLError::MatrixError("Matrix must be square".to_string()));
118        }
119
120        if b.len() != rows {
121            return Err(FTLError::MatrixError(
122                "Vector dimension mismatch".to_string(),
123            ));
124        }
125
126        // Check for diagonal dominance (ensures convergence)
127        if !self.is_diagonally_dominant(a) {
128            // Warning only - some methods can handle this
129            log::warn!("Matrix is not diagonally dominant - convergence not guaranteed");
130        }
131
132        Ok(())
133    }
134
135    /// Check if matrix is diagonally dominant
136    fn is_diagonally_dominant(&self, a: &Matrix) -> bool {
137        let (n, _) = a.shape();
138        let view = a.view();
139
140        for i in 0..n {
141            let diagonal = view[[i, i]].abs();
142            let mut off_diagonal_sum = 0.0;
143
144            for j in 0..n {
145                if i != j {
146                    off_diagonal_sum += view[[i, j]].abs();
147                }
148            }
149
150            if diagonal <= off_diagonal_sum {
151                return false;
152            }
153        }
154
155        true
156    }
157
158    /// Select best method based on matrix properties
159    fn select_best_method(&self, a: &Matrix) -> SolverMethod {
160        let sparse = a.to_sparse();
161        let sparsity = sparse.sparsity();
162
163        if sparsity > 0.95 {
164            // Very sparse - use forward push
165            SolverMethod::ForwardPush
166        } else if self.is_diagonally_dominant(a) {
167            // Well-conditioned - use Neumann
168            SolverMethod::Neumann
169        } else {
170            // General case - use bidirectional
171            SolverMethod::Bidirectional
172        }
173    }
174
175    /// Neumann series: x = (I - M)^(-1)b where A = I - M
176    fn solve_neumann(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
177        let n = b.len();
178        let mut x = b.clone();
179        let identity_minus_a = self.compute_iteration_matrix(a)?;
180
181        // Neumann series: x = b + Mb + M²b + M³b + ...
182        // Converges in O(log n) iterations for well-conditioned matrices
183        let iterations = (n as f64).log2().ceil() as usize;
184        let actual_iterations = iterations.min(self.config.max_iterations);
185
186        for _ in 0..actual_iterations {
187            let mx = identity_minus_a.multiply_vector(&x);
188            let new_x = b.add(&mx);
189
190            // Check convergence
191            let diff = new_x.sub(&x).norm();
192            if diff < self.config.epsilon {
193                return Ok(new_x);
194            }
195
196            x = new_x;
197        }
198
199        Ok(x)
200    }
201
202    /// Random walk Monte Carlo solver - probabilistic O(log n)
203    fn solve_random_walk(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
204        use rand::Rng;
205        let mut rng = rand::thread_rng();
206        let n = b.len();
207        let mut solution = Vector::zeros(n);
208
209        // Number of walks scales logarithmically
210        let num_walks = ((n as f64).log2() * 100.0) as usize;
211        let walk_length = (n as f64).log2().ceil() as usize;
212
213        for i in 0..n {
214            let mut estimate = 0.0;
215
216            for _ in 0..num_walks {
217                // Random walk starting from node i
218                let mut current = i;
219                let mut weight = 1.0;
220
221                for _ in 0..walk_length {
222                    // Random transition
223                    let next = rng.gen_range(0..n);
224                    weight *= a.view()[[current, next]];
225                    current = next;
226
227                    if weight.abs() < 1e-10 {
228                        break;
229                    }
230                }
231
232                estimate += weight * b.view()[current];
233            }
234
235            solution.data[i] = estimate / num_walks as f64;
236        }
237
238        Ok(solution)
239    }
240
241    /// Forward push algorithm - deterministic O(log n)
242    fn solve_forward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
243        let n = b.len();
244        let mut solution = b.clone();
245        let mut residual = b.clone();
246
247        // Push threshold scales with epsilon and dimension
248        let threshold = self.config.epsilon / (n as f64).sqrt();
249        let max_pushes = (n as f64).log2().ceil() as usize * 10;
250
251        for _ in 0..max_pushes {
252            // Find node with largest residual
253            let mut max_residual = 0.0;
254            let mut max_idx = 0;
255
256            for i in 0..n {
257                if residual.data[i].abs() > max_residual {
258                    max_residual = residual.data[i].abs();
259                    max_idx = i;
260                }
261            }
262
263            if max_residual < threshold {
264                break;
265            }
266
267            // Push from max_idx
268            let push_value = residual.data[max_idx];
269            solution.data[max_idx] += push_value;
270
271            // Update residuals of neighbors
272            for j in 0..n {
273                residual.data[j] -= push_value * a.view()[[max_idx, j]];
274            }
275            residual.data[max_idx] = 0.0;
276        }
277
278        Ok(solution)
279    }
280
281    /// Backward push algorithm
282    fn solve_backward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
283        // Similar to forward but propagates backwards
284        self.solve_forward_push(a, b) // Simplified for now
285    }
286
287    /// Bidirectional push - combines forward and backward
288    fn solve_bidirectional(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
289        let forward = self.solve_forward_push(a, b)?;
290        let backward = self.solve_backward_push(a, b)?;
291
292        // Average the two solutions
293        Ok(forward.add(&backward).scale(0.5))
294    }
295
296    /// Compute iteration matrix for Neumann series
297    fn compute_iteration_matrix(&self, a: &Matrix) -> crate::Result<Matrix> {
298        let (n, _) = a.shape();
299        let mut m = Matrix::random(n, n);
300
301        // M = I - D^(-1)A where D is diagonal of A
302        for i in 0..n {
303            for j in 0..n {
304                if i == j {
305                    m.data[[i, j]] = 0.0;
306                } else {
307                    let diagonal = a.view()[[i, i]];
308                    if diagonal.abs() > 1e-10 {
309                        m.data[[i, j]] = -a.view()[[i, j]] / diagonal;
310                    }
311                }
312            }
313        }
314
315        Ok(m)
316    }
317
318    /// Compute residual ||Ax - b||
319    fn compute_residual(&self, a: &Matrix, x: &Vector, b: &Vector) -> f64 {
320        let ax = a.multiply_vector(x);
321        ax.sub(b).norm()
322    }
323
324    /// Estimate actual complexity from runtime
325    fn estimate_complexity(&self, n: usize, elapsed: Duration) -> Complexity {
326        let nanos = elapsed.as_nanos() as f64;
327        let log_n = (n as f64).log2();
328
329        // Compare with theoretical complexities
330        let ratios = vec![
331            (Complexity::Constant, 1.0),
332            (Complexity::Logarithmic, log_n),
333            (Complexity::Linear, n as f64),
334            (Complexity::Quadratic, (n * n) as f64),
335            (Complexity::Cubic, (n * n * n) as f64),
336        ];
337
338        // Find best fit
339        let mut best_complexity = Complexity::Cubic;
340        let mut min_diff = f64::MAX;
341
342        for (complexity, theoretical) in ratios {
343            let diff = (nanos / theoretical - 1.0).abs();
344            if diff < min_diff {
345                min_diff = diff;
346                best_complexity = complexity;
347            }
348        }
349
350        best_complexity
351    }
352}
353
354/// Result of solving a linear system
355#[derive(Debug, Clone)]
356pub struct SolverResult {
357    pub solution: Vector,
358    pub method: SolverMethod,
359    pub iterations: usize,
360    pub residual: f64,
361    pub time: Duration,
362    pub complexity: Complexity,
363}
364
365impl SolverResult {
366    /// Check if solution converged
367    pub fn converged(&self, tolerance: f64) -> bool {
368        self.residual < tolerance
369    }
370
371    /// Get solve time in microseconds
372    pub fn time_microseconds(&self) -> f64 {
373        self.time.as_secs_f64() * 1_000_000.0
374    }
375
376    /// Check if complexity is sublinear
377    pub fn is_sublinear(&self) -> bool {
378        matches!(
379            self.complexity,
380            Complexity::Constant | Complexity::Logarithmic
381        )
382    }
383}
384
385#[cfg(test)]
386mod tests {
387    use super::*;
388
389    #[test]
390    fn test_neumann_solver() {
391        let a = Matrix::diagonally_dominant(10, 2.0);
392        let b = Vector::ones(10);
393        let solver = SublinearSolver::with_method(SolverMethod::Neumann);
394
395        let result = solver.solve(&a, &b).unwrap();
396        assert!(result.is_sublinear());
397    }
398
399    #[test]
400    fn test_forward_push() {
401        let a = Matrix::diagonally_dominant(100, 3.0);
402        let b = Vector::random(100);
403        let solver = SublinearSolver::with_method(SolverMethod::ForwardPush);
404
405        let result = solver.solve(&a, &b).unwrap();
406        assert!(result.time_microseconds() < 1000.0); // Should be very fast
407    }
408
409    #[test]
410    fn test_adaptive_selection() {
411        let sparse = Matrix::diagonally_dominant(50, 5.0);
412        let b = Vector::ones(50);
413        let solver = SublinearSolver::new(); // Uses adaptive
414
415        let result = solver.solve(&sparse, &b).unwrap();
416        assert!(result.converged(1e-3));
417    }
418}