Skip to main content

sublinear_solver/solver/
mod.rs

//! Sublinear-time solver algorithms for asymmetric diagonally dominant systems.
//!
//! This module implements the core solver algorithms, including the Neumann series
//! solver, forward/backward push methods, and hybrid random-walk approaches. The
//! forward/backward push and hybrid solvers declared here are currently placeholders.
5
6use crate::error::{Result, SolverError};
7use crate::matrix::Matrix;
8use crate::types::{
9    ConvergenceMode, ErrorBounds, MemoryInfo, NormType, Precision, ProfileData, SolverStats,
10};
11use alloc::{string::String, vec::Vec};
12
13pub mod neumann;
14
15// Re-export solver implementations
16pub use neumann::NeumannSolver;
17
18/// Configuration options for solver algorithms.
19#[derive(Debug, Clone, PartialEq)]
20#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
21pub struct SolverOptions {
22    /// Convergence tolerance
23    pub tolerance: Precision,
24    /// Maximum number of iterations
25    pub max_iterations: usize,
26    /// Convergence detection mode
27    pub convergence_mode: ConvergenceMode,
28    /// Norm type for error measurement
29    pub norm_type: NormType,
30    /// Enable detailed statistics collection
31    pub collect_stats: bool,
32    /// Streaming solution interval (0 = no streaming)
33    pub streaming_interval: usize,
34    /// Initial guess for the solution (if None, use zero)
35    pub initial_guess: Option<Vec<Precision>>,
36    /// Enable error bounds computation
37    pub compute_error_bounds: bool,
38    /// Relative tolerance for error bounds
39    pub error_bounds_tolerance: Precision,
40    /// Enable performance profiling
41    pub enable_profiling: bool,
42    /// Random seed for stochastic algorithms
43    pub random_seed: Option<u64>,
44    /// Coherence gate threshold (ADR-001 roadmap item #3). If the matrix's
45    /// diagonal-dominance margin (`coherence::coherence_score`) falls below
46    /// this value, the solver returns `Err(SolverError::Incoherent { .. })`
47    /// *before* doing any iterative work — refusing to spend polynomial
48    /// cost on a near-singular system that can only produce an ε-quality
49    /// answer.
50    ///
51    /// `0.0` (the default) disables the gate, preserving wire compatibility
52    /// with every existing caller. Recommended starting value when enabling:
53    /// `0.05`.
54    pub coherence_threshold: Precision,
55}
56
57impl Default for SolverOptions {
58    fn default() -> Self {
59        Self {
60            tolerance: 1e-6,
61            max_iterations: 1000,
62            convergence_mode: ConvergenceMode::ResidualNorm,
63            norm_type: NormType::L2,
64            collect_stats: false,
65            streaming_interval: 0,
66            initial_guess: None,
67            compute_error_bounds: false,
68            error_bounds_tolerance: 1e-8,
69            enable_profiling: false,
70            random_seed: None,
71            // Gate disabled by default. Callers opt in by setting > 0.
72            coherence_threshold: 0.0,
73        }
74    }
75}
76
77impl SolverOptions {
78    /// Create options optimized for high precision.
79    pub fn high_precision() -> Self {
80        Self {
81            tolerance: 1e-12,
82            max_iterations: 5000,
83            convergence_mode: ConvergenceMode::Combined,
84            norm_type: NormType::L2,
85            collect_stats: true,
86            streaming_interval: 0,
87            initial_guess: None,
88            compute_error_bounds: true,
89            error_bounds_tolerance: 1e-14,
90            enable_profiling: false,
91            random_seed: None,
92            coherence_threshold: 0.0,
93        }
94    }
95
96    /// Create options optimized for fast solving.
97    pub fn fast() -> Self {
98        Self {
99            tolerance: 1e-3,
100            max_iterations: 100,
101            convergence_mode: ConvergenceMode::ResidualNorm,
102            norm_type: NormType::L2,
103            collect_stats: false,
104            streaming_interval: 0,
105            initial_guess: None,
106            compute_error_bounds: false,
107            error_bounds_tolerance: 1e-4,
108            enable_profiling: false,
109            random_seed: None,
110            coherence_threshold: 0.0,
111        }
112    }
113
114    /// Create options optimized for streaming applications.
115    pub fn streaming(interval: usize) -> Self {
116        Self {
117            tolerance: 1e-4,
118            max_iterations: 1000,
119            convergence_mode: ConvergenceMode::ResidualNorm,
120            norm_type: NormType::L2,
121            collect_stats: true,
122            streaming_interval: interval,
123            initial_guess: None,
124            compute_error_bounds: false,
125            error_bounds_tolerance: 1e-6,
126            enable_profiling: true,
127            random_seed: None,
128            coherence_threshold: 0.0,
129        }
130    }
131}
132
/// Result of a solver computation.
///
/// Returned by the default `SolverAlgorithm::solve` and built by the
/// `SolverResult::success` / `SolverResult::failure` / `SolverResult::error`
/// constructors. Those constructors leave every optional field as `None`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverResult {
    /// Final solution vector (empty for results built with `SolverResult::error`).
    pub solution: Vec<Precision>,
    /// Final residual norm (`Precision::INFINITY` for results built with
    /// `SolverResult::error`).
    pub residual_norm: Precision,
    /// Number of iterations performed
    pub iterations: usize,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Error bounds. The default `solve` fills this from
    /// `SolverState::error_bounds` only when `SolverOptions::compute_error_bounds`
    /// is set.
    pub error_bounds: Option<ErrorBounds>,
    /// Detailed statistics. The default `solve` fills this only when
    /// `SolverOptions::collect_stats` is set and the `std` feature is enabled.
    pub stats: Option<SolverStats>,
    /// Memory usage information.
    /// NOTE(review): the default `solve` in this module never populates this.
    pub memory_info: Option<MemoryInfo>,
    /// Performance profiling data.
    /// NOTE(review): the default `solve` in this module never populates this.
    pub profile_data: Option<Vec<ProfileData>>,
}
154
155impl SolverResult {
156    /// Create a successful result.
157    pub fn success(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
158        Self {
159            solution,
160            residual_norm,
161            iterations,
162            converged: true,
163            error_bounds: None,
164            stats: None,
165            memory_info: None,
166            profile_data: None,
167        }
168    }
169
170    /// Create a failure result.
171    pub fn failure(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
172        Self {
173            solution,
174            residual_norm,
175            iterations,
176            converged: false,
177            error_bounds: None,
178            stats: None,
179            memory_info: None,
180            profile_data: None,
181        }
182    }
183
184    /// Create an error result.
185    pub fn error(error: SolverError) -> Self {
186        Self {
187            solution: Vec::new(),
188            residual_norm: Precision::INFINITY,
189            iterations: 0,
190            converged: false,
191            error_bounds: None,
192            stats: None,
193            memory_info: None,
194            profile_data: None,
195        }
196    }
197
198    /// Check if the solution meets the specified quality criteria.
199    pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
200        self.converged && self.residual_norm <= tolerance
201    }
202}
203
/// Partial solution for streaming applications.
///
/// NOTE(review): nothing in this module constructs a `PartialSolution`;
/// presumably streaming solver implementations elsewhere produce it — confirm.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PartialSolution {
    /// Current iteration number
    pub iteration: usize,
    /// Current solution estimate
    pub solution: Vec<Precision>,
    /// Current residual norm
    pub residual_norm: Precision,
    /// Whether convergence has been achieved
    pub converged: bool,
    /// Estimated remaining iterations
    pub estimated_remaining: Option<usize>,
    /// Timestamp when this solution was computed (`std` builds).
    ///
    /// Skipped by serde; on deserialization it is set to `Instant::now()`.
    #[cfg(feature = "std")]
    #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
    pub timestamp: std::time::Instant,
    /// Timestamp when this solution was computed (`no_std` builds).
    ///
    /// Unlike the `std` variant, this is a plain integer and IS serialized.
    /// NOTE(review): the unit and epoch are not established here — confirm
    /// with whoever produces it.
    #[cfg(not(feature = "std"))]
    pub timestamp: u64,
}
225
226/// Core trait for all solver algorithms.
227///
228/// This trait defines the interface that all sublinear-time solvers must implement,
229/// providing both batch and streaming solution capabilities.
230pub trait SolverAlgorithm: Send + Sync {
231    /// Solver-specific state type
232    type State: SolverState;
233
234    /// Initialize the solver state for a given problem.
235    fn initialize(
236        &self,
237        matrix: &dyn Matrix,
238        b: &[Precision],
239        options: &SolverOptions,
240    ) -> Result<Self::State>;
241
242    /// Perform a single iteration step.
243    fn step(&self, state: &mut Self::State) -> Result<StepResult>;
244
245    /// Check if the current state meets convergence criteria.
246    fn is_converged(&self, state: &Self::State) -> bool;
247
248    /// Extract the current solution from the state.
249    fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
250
251    /// Update the right-hand side for incremental solving.
252    fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
253
254    /// Get the algorithm name for identification.
255    fn algorithm_name(&self) -> &'static str;
256
257    /// Solve the linear system Ax = b.
258    ///
259    /// This is the main interface for solving linear systems. It handles
260    /// the iteration loop and convergence checking automatically.
261    fn solve(
262        &self,
263        matrix: &dyn Matrix,
264        b: &[Precision],
265        options: &SolverOptions,
266    ) -> Result<SolverResult> {
267        let mut state = self.initialize(matrix, b, options)?;
268        let mut iterations = 0;
269
270        #[cfg(feature = "std")]
271        let start_time = std::time::Instant::now();
272
273        while !self.is_converged(&state) && iterations < options.max_iterations {
274            match self.step(&mut state)? {
275                StepResult::Continue => {
276                    iterations += 1;
277
278                    // Check for numerical issues
279                    let residual = state.residual_norm();
280                    if !residual.is_finite() {
281                        return Err(SolverError::NumericalInstability {
282                            reason: "Non-finite residual norm".to_string(),
283                            iteration: iterations,
284                            residual_norm: residual,
285                        });
286                    }
287                }
288                StepResult::Converged => break,
289                StepResult::Failed(reason) => {
290                    return Err(SolverError::AlgorithmError {
291                        algorithm: self.algorithm_name().to_string(),
292                        message: reason,
293                        context: vec![
294                            ("iteration".to_string(), iterations.to_string()),
295                            (
296                                "residual_norm".to_string(),
297                                state.residual_norm().to_string(),
298                            ),
299                        ],
300                    });
301                }
302            }
303        }
304
305        let converged = self.is_converged(&state);
306        let solution = self.extract_solution(&state);
307        let residual_norm = state.residual_norm();
308
309        // Check for convergence failure
310        if !converged && iterations >= options.max_iterations {
311            return Err(SolverError::ConvergenceFailure {
312                iterations,
313                residual_norm,
314                tolerance: options.tolerance,
315                algorithm: self.algorithm_name().to_string(),
316            });
317        }
318
319        let mut result = if converged {
320            SolverResult::success(solution, residual_norm, iterations)
321        } else {
322            SolverResult::failure(solution, residual_norm, iterations)
323        };
324
325        // Add optional data if requested
326        if options.collect_stats {
327            #[cfg(feature = "std")]
328            {
329                let total_time = start_time.elapsed().as_millis() as f64;
330                let mut stats = SolverStats::new();
331                stats.total_time_ms = total_time;
332                stats.matvec_count = state.matvec_count();
333                result.stats = Some(stats);
334            }
335        }
336
337        if options.compute_error_bounds {
338            result.error_bounds = state.error_bounds();
339        }
340
341        Ok(result)
342    }
343}
344
/// Trait for solver state management.
///
/// How the default `SolverAlgorithm::solve` in this module uses a state:
/// - it reads `residual_norm` after every `Continue` step (to detect
///   non-finite values) and once more at the end;
/// - it reads `matvec_count` only when collecting statistics;
/// - it reads `error_bounds` only when `compute_error_bounds` is set.
pub trait SolverState: Send + Sync {
    /// Get the current residual norm.
    fn residual_norm(&self) -> Precision;

    /// Get the number of matrix-vector multiplications performed.
    fn matvec_count(&self) -> usize;

    /// Get error bounds if available.
    fn error_bounds(&self) -> Option<ErrorBounds>;

    /// Get current memory usage.
    ///
    /// NOTE(review): not called by the default `solve` in this module.
    fn memory_usage(&self) -> MemoryInfo;

    /// Reset the state for a new solve.
    ///
    /// NOTE(review): not called by the default `solve` in this module.
    fn reset(&mut self);
}
362
/// Result of a single iteration step, as returned by `SolverAlgorithm::step`.
///
/// All payloads are `Eq`, so the enum derives `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepResult {
    /// Continue iterating. The default `solve` counts the step and checks
    /// the residual for non-finite values.
    Continue,
    /// Convergence achieved. The default `solve` stops iterating.
    Converged,
    /// Algorithm failed with the given reason. The default `solve` reports it
    /// as `SolverError::AlgorithmError`.
    Failed(String),
}
373
374/// Utility functions for solver implementations.
375pub mod utils {
376    use super::*;
377
378    /// Compute the L2 norm of a vector.
379    pub fn l2_norm(v: &[Precision]) -> Precision {
380        v.iter().map(|x| x * x).sum::<Precision>().sqrt()
381    }
382
383    /// Compute the L1 norm of a vector.
384    pub fn l1_norm(v: &[Precision]) -> Precision {
385        v.iter().map(|x| x.abs()).sum()
386    }
387
388    /// Compute the L∞ norm of a vector.
389    pub fn linf_norm(v: &[Precision]) -> Precision {
390        v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
391    }
392
393    /// Compute vector norm according to specified type.
394    pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
395        match norm_type {
396            NormType::L1 => l1_norm(v),
397            NormType::L2 => l2_norm(v),
398            NormType::LInfinity => linf_norm(v),
399            NormType::Weighted => l2_norm(v), // Default to L2 for weighted
400        }
401    }
402
403    /// Compute residual vector: r = A*x - b
404    pub fn compute_residual(
405        matrix: &dyn Matrix,
406        x: &[Precision],
407        b: &[Precision],
408        residual: &mut [Precision],
409    ) -> Result<()> {
410        matrix.multiply_vector(x, residual)?;
411        for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
412            *r -= b_val;
413        }
414        Ok(())
415    }
416
417    /// Check convergence based on specified criteria.
418    pub fn check_convergence(
419        residual_norm: Precision,
420        tolerance: Precision,
421        mode: ConvergenceMode,
422        b_norm: Precision,
423        prev_solution: Option<&[Precision]>,
424        current_solution: &[Precision],
425    ) -> bool {
426        match mode {
427            ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
428            ConvergenceMode::RelativeResidual => {
429                if b_norm > 0.0 {
430                    (residual_norm / b_norm) <= tolerance
431                } else {
432                    residual_norm <= tolerance
433                }
434            }
435            ConvergenceMode::SolutionChange => {
436                if let Some(prev) = prev_solution {
437                    let mut change_norm = 0.0;
438                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
439                        let diff = curr - prev_val;
440                        change_norm += diff * diff;
441                    }
442                    change_norm.sqrt() <= tolerance
443                } else {
444                    false
445                }
446            }
447            ConvergenceMode::RelativeSolutionChange => {
448                if let Some(prev) = prev_solution {
449                    let mut change_norm = 0.0;
450                    let mut solution_norm = 0.0;
451                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
452                        let diff = curr - prev_val;
453                        change_norm += diff * diff;
454                        solution_norm += prev_val * prev_val;
455                    }
456                    if solution_norm > 0.0 {
457                        (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
458                    } else {
459                        change_norm.sqrt() <= tolerance
460                    }
461                } else {
462                    false
463                }
464            }
465            ConvergenceMode::Combined => {
466                // Use the most conservative criterion
467                residual_norm <= tolerance
468                    && (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
469            }
470        }
471    }
472}
473
// Solver types whose algorithms are not implemented yet; their
// `SolverAlgorithm` impls in this module are placeholders.

/// Forward-push solver (placeholder; not implemented yet).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ForwardPushSolver;

/// Backward-push solver (placeholder; not implemented yet).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BackwardPushSolver;

/// Hybrid random-walk solver (placeholder; not implemented yet).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HybridSolver;
478
479// Placeholder implementations - will be implemented in separate modules
480impl SolverAlgorithm for ForwardPushSolver {
481    type State = ();
482
483    fn initialize(
484        &self,
485        _matrix: &dyn Matrix,
486        _b: &[Precision],
487        _options: &SolverOptions,
488    ) -> Result<Self::State> {
489        Err(SolverError::AlgorithmError {
490            algorithm: "forward_push".to_string(),
491            message: "Not implemented yet".to_string(),
492            context: vec![],
493        })
494    }
495
496    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
497        Err(SolverError::AlgorithmError {
498            algorithm: "forward_push".to_string(),
499            message: "Not implemented yet".to_string(),
500            context: vec![],
501        })
502    }
503
504    fn is_converged(&self, _state: &Self::State) -> bool {
505        false
506    }
507
508    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
509        Vec::new()
510    }
511
512    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
513        Err(SolverError::AlgorithmError {
514            algorithm: "forward_push".to_string(),
515            message: "Not implemented yet".to_string(),
516            context: vec![],
517        })
518    }
519
520    fn algorithm_name(&self) -> &'static str {
521        "forward_push"
522    }
523}
524
525impl SolverState for () {
526    fn residual_norm(&self) -> Precision {
527        0.0
528    }
529
530    fn matvec_count(&self) -> usize {
531        0
532    }
533
534    fn error_bounds(&self) -> Option<ErrorBounds> {
535        None
536    }
537
538    fn memory_usage(&self) -> MemoryInfo {
539        MemoryInfo {
540            current_usage_bytes: 0,
541            peak_usage_bytes: 0,
542            matrix_memory_bytes: 0,
543            vector_memory_bytes: 0,
544            workspace_memory_bytes: 0,
545            allocation_count: 0,
546            deallocation_count: 0,
547        }
548    }
549
550    fn reset(&mut self) {}
551}
552
553// Similar placeholder implementations for BackwardPushSolver and HybridSolver
554impl SolverAlgorithm for BackwardPushSolver {
555    type State = ();
556    fn initialize(
557        &self,
558        _matrix: &dyn Matrix,
559        _b: &[Precision],
560        _options: &SolverOptions,
561    ) -> Result<Self::State> {
562        Ok(())
563    }
564    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
565        Ok(StepResult::Converged)
566    }
567    fn is_converged(&self, _state: &Self::State) -> bool {
568        true
569    }
570    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
571        Vec::new()
572    }
573    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
574        Ok(())
575    }
576    fn algorithm_name(&self) -> &'static str {
577        "backward_push"
578    }
579}
580
581impl SolverAlgorithm for HybridSolver {
582    type State = ();
583    fn initialize(
584        &self,
585        _matrix: &dyn Matrix,
586        _b: &[Precision],
587        _options: &SolverOptions,
588    ) -> Result<Self::State> {
589        Ok(())
590    }
591    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
592        Ok(StepResult::Converged)
593    }
594    fn is_converged(&self, _state: &Self::State) -> bool {
595        true
596    }
597    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
598        Vec::new()
599    }
600    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
601        Ok(())
602    }
603    fn algorithm_name(&self) -> &'static str {
604        "hybrid"
605    }
606}
607
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);
        // The coherence gate is opt-in.
        assert_eq!(default_opts.coherence_threshold, 0.0);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);

        let streaming_opts = SolverOptions::streaming(5);
        assert_eq!(streaming_opts.streaming_interval, 5);
        assert!(streaming_opts.collect_stats);
        assert!(streaming_opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));

        // A non-converged result never meets the criteria, whatever its residual.
        let failed = SolverResult::failure(vec![1.0], 1e-8, 10);
        assert!(!failed.converged);
        assert!(!failed.meets_quality_criteria(1.0));

        let errored = SolverResult::error(SolverError::AlgorithmError {
            algorithm: "test".to_string(),
            message: "boom".to_string(),
            context: vec![],
        });
        assert!(errored.solution.is_empty());
        assert!(errored.residual_norm.is_infinite());
        assert_eq!(errored.iterations, 0);
        assert!(!errored.converged);
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = [3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);
        // Weighted currently falls back to L2.
        assert_eq!(compute_norm(&v, NormType::Weighted), 5.0);
    }

    #[test]
    fn test_check_convergence_modes() {
        use utils::check_convergence;

        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[]));
        assert!(!check_convergence(1e-5, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[]));
        // 0.5 / 10.0 = 0.05 <= 0.1
        assert!(check_convergence(
            0.5,
            0.1,
            ConvergenceMode::RelativeResidual,
            10.0,
            None,
            &[]
        ));
        // Solution-change modes need a previous iterate.
        assert!(!check_convergence(
            0.0,
            1.0,
            ConvergenceMode::SolutionChange,
            1.0,
            None,
            &[1.0]
        ));
        assert!(check_convergence(
            1.0,
            1e-6,
            ConvergenceMode::SolutionChange,
            1.0,
            Some(&[1.0, 2.0][..]),
            &[1.0, 2.0]
        ));
    }

    #[test]
    fn test_step_result_equality() {
        assert_eq!(StepResult::Converged, StepResult::Converged);
        assert_ne!(StepResult::Continue, StepResult::Failed("x".to_string()));
    }
}