// sublinear_solver/solver/mod.rs

//! Sublinear-time solver algorithms for asymmetric diagonally dominant systems.
//!
//! This module implements the core solver algorithms, including the Neumann series,
//! forward/backward push methods, and hybrid random-walk approaches. The push and
//! hybrid solvers are currently placeholders.
5
6use crate::matrix::Matrix;
7use crate::types::{
8    Precision, ConvergenceMode, NormType, ErrorBounds, SolverStats,
9    DimensionType, MemoryInfo, ProfileData
10};
11use crate::error::{SolverError, Result};
12use alloc::{vec::Vec, string::String, boxed::Box};
13
14pub mod neumann;
15
16// Re-export solver implementations
17pub use neumann::NeumannSolver;
18
19/// Configuration options for solver algorithms.
20#[derive(Debug, Clone, PartialEq)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SolverOptions {
23    /// Convergence tolerance
24    pub tolerance: Precision,
25    /// Maximum number of iterations
26    pub max_iterations: usize,
27    /// Convergence detection mode
28    pub convergence_mode: ConvergenceMode,
29    /// Norm type for error measurement
30    pub norm_type: NormType,
31    /// Enable detailed statistics collection
32    pub collect_stats: bool,
33    /// Streaming solution interval (0 = no streaming)
34    pub streaming_interval: usize,
35    /// Initial guess for the solution (if None, use zero)
36    pub initial_guess: Option<Vec<Precision>>,
37    /// Enable error bounds computation
38    pub compute_error_bounds: bool,
39    /// Relative tolerance for error bounds
40    pub error_bounds_tolerance: Precision,
41    /// Enable performance profiling
42    pub enable_profiling: bool,
43    /// Random seed for stochastic algorithms
44    pub random_seed: Option<u64>,
45}
46
47impl Default for SolverOptions {
48    fn default() -> Self {
49        Self {
50            tolerance: 1e-6,
51            max_iterations: 1000,
52            convergence_mode: ConvergenceMode::ResidualNorm,
53            norm_type: NormType::L2,
54            collect_stats: false,
55            streaming_interval: 0,
56            initial_guess: None,
57            compute_error_bounds: false,
58            error_bounds_tolerance: 1e-8,
59            enable_profiling: false,
60            random_seed: None,
61        }
62    }
63}
64
65impl SolverOptions {
66    /// Create options optimized for high precision.
67    pub fn high_precision() -> Self {
68        Self {
69            tolerance: 1e-12,
70            max_iterations: 5000,
71            convergence_mode: ConvergenceMode::Combined,
72            norm_type: NormType::L2,
73            collect_stats: true,
74            streaming_interval: 0,
75            initial_guess: None,
76            compute_error_bounds: true,
77            error_bounds_tolerance: 1e-14,
78            enable_profiling: false,
79            random_seed: None,
80        }
81    }
82
83    /// Create options optimized for fast solving.
84    pub fn fast() -> Self {
85        Self {
86            tolerance: 1e-3,
87            max_iterations: 100,
88            convergence_mode: ConvergenceMode::ResidualNorm,
89            norm_type: NormType::L2,
90            collect_stats: false,
91            streaming_interval: 0,
92            initial_guess: None,
93            compute_error_bounds: false,
94            error_bounds_tolerance: 1e-4,
95            enable_profiling: false,
96            random_seed: None,
97        }
98    }
99
100    /// Create options optimized for streaming applications.
101    pub fn streaming(interval: usize) -> Self {
102        Self {
103            tolerance: 1e-4,
104            max_iterations: 1000,
105            convergence_mode: ConvergenceMode::ResidualNorm,
106            norm_type: NormType::L2,
107            collect_stats: true,
108            streaming_interval: interval,
109            initial_guess: None,
110            compute_error_bounds: false,
111            error_bounds_tolerance: 1e-6,
112            enable_profiling: true,
113            random_seed: None,
114        }
115    }
116}
117
118/// Result of a solver computation.
119#[derive(Debug, Clone, PartialEq)]
120#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
121pub struct SolverResult {
122    /// Final solution vector
123    pub solution: Vec<Precision>,
124    /// Final residual norm
125    pub residual_norm: Precision,
126    /// Number of iterations performed
127    pub iterations: usize,
128    /// Whether the algorithm converged
129    pub converged: bool,
130    /// Error bounds (if computed)
131    pub error_bounds: Option<ErrorBounds>,
132    /// Detailed statistics (if collected)
133    pub stats: Option<SolverStats>,
134    /// Memory usage information
135    pub memory_info: Option<MemoryInfo>,
136    /// Performance profiling data
137    pub profile_data: Option<Vec<ProfileData>>,
138}
139
140impl SolverResult {
141    /// Create a successful result.
142    pub fn success(
143        solution: Vec<Precision>,
144        residual_norm: Precision,
145        iterations: usize,
146    ) -> Self {
147        Self {
148            solution,
149            residual_norm,
150            iterations,
151            converged: true,
152            error_bounds: None,
153            stats: None,
154            memory_info: None,
155            profile_data: None,
156        }
157    }
158
159    /// Create a failure result.
160    pub fn failure(
161        solution: Vec<Precision>,
162        residual_norm: Precision,
163        iterations: usize,
164    ) -> Self {
165        Self {
166            solution,
167            residual_norm,
168            iterations,
169            converged: false,
170            error_bounds: None,
171            stats: None,
172            memory_info: None,
173            profile_data: None,
174        }
175    }
176
177    /// Create an error result.
178    pub fn error(error: SolverError) -> Self {
179        Self {
180            solution: Vec::new(),
181            residual_norm: Precision::INFINITY,
182            iterations: 0,
183            converged: false,
184            error_bounds: None,
185            stats: None,
186            memory_info: None,
187            profile_data: None,
188        }
189    }
190
191    /// Check if the solution meets the specified quality criteria.
192    pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
193        self.converged && self.residual_norm <= tolerance
194    }
195}
196
197/// Partial solution for streaming applications.
198#[derive(Debug, Clone, PartialEq)]
199#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
200pub struct PartialSolution {
201    /// Current iteration number
202    pub iteration: usize,
203    /// Current solution estimate
204    pub solution: Vec<Precision>,
205    /// Current residual norm
206    pub residual_norm: Precision,
207    /// Whether convergence has been achieved
208    pub converged: bool,
209    /// Estimated remaining iterations
210    pub estimated_remaining: Option<usize>,
211    /// Timestamp when this solution was computed (not serialized)
212    #[cfg(feature = "std")]
213    #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
214    pub timestamp: std::time::Instant,
215    #[cfg(not(feature = "std"))]
216    pub timestamp: u64,
217}
218
219/// Core trait for all solver algorithms.
220///
221/// This trait defines the interface that all sublinear-time solvers must implement,
222/// providing both batch and streaming solution capabilities.
223pub trait SolverAlgorithm: Send + Sync {
224    /// Solver-specific state type
225    type State: SolverState;
226
227    /// Initialize the solver state for a given problem.
228    fn initialize(
229        &self,
230        matrix: &dyn Matrix,
231        b: &[Precision],
232        options: &SolverOptions,
233    ) -> Result<Self::State>;
234
235    /// Perform a single iteration step.
236    fn step(&self, state: &mut Self::State) -> Result<StepResult>;
237
238    /// Check if the current state meets convergence criteria.
239    fn is_converged(&self, state: &Self::State) -> bool;
240
241    /// Extract the current solution from the state.
242    fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
243
244    /// Update the right-hand side for incremental solving.
245    fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
246
247    /// Get the algorithm name for identification.
248    fn algorithm_name(&self) -> &'static str;
249
250    /// Solve the linear system Ax = b.
251    ///
252    /// This is the main interface for solving linear systems. It handles
253    /// the iteration loop and convergence checking automatically.
254    fn solve(
255        &self,
256        matrix: &dyn Matrix,
257        b: &[Precision],
258        options: &SolverOptions,
259    ) -> Result<SolverResult> {
260        let mut state = self.initialize(matrix, b, options)?;
261        let mut iterations = 0;
262
263        #[cfg(feature = "std")]
264        let start_time = std::time::Instant::now();
265
266        while !self.is_converged(&state) && iterations < options.max_iterations {
267            match self.step(&mut state)? {
268                StepResult::Continue => {
269                    iterations += 1;
270
271                    // Check for numerical issues
272                    let residual = state.residual_norm();
273                    if !residual.is_finite() {
274                        return Err(SolverError::NumericalInstability {
275                            reason: "Non-finite residual norm".to_string(),
276                            iteration: iterations,
277                            residual_norm: residual,
278                        });
279                    }
280                },
281                StepResult::Converged => break,
282                StepResult::Failed(reason) => {
283                    return Err(SolverError::AlgorithmError {
284                        algorithm: self.algorithm_name().to_string(),
285                        message: reason,
286                        context: vec![
287                            ("iteration".to_string(), iterations.to_string()),
288                            ("residual_norm".to_string(), state.residual_norm().to_string()),
289                        ],
290                    });
291                }
292            }
293        }
294
295        let converged = self.is_converged(&state);
296        let solution = self.extract_solution(&state);
297        let residual_norm = state.residual_norm();
298
299        // Check for convergence failure
300        if !converged && iterations >= options.max_iterations {
301            return Err(SolverError::ConvergenceFailure {
302                iterations,
303                residual_norm,
304                tolerance: options.tolerance,
305                algorithm: self.algorithm_name().to_string(),
306            });
307        }
308
309        let mut result = if converged {
310            SolverResult::success(solution, residual_norm, iterations)
311        } else {
312            SolverResult::failure(solution, residual_norm, iterations)
313        };
314
315        // Add optional data if requested
316        if options.collect_stats {
317            #[cfg(feature = "std")]
318            {
319                let total_time = start_time.elapsed().as_millis() as f64;
320                let mut stats = SolverStats::new();
321                stats.total_time_ms = total_time;
322                stats.matvec_count = state.matvec_count();
323                result.stats = Some(stats);
324            }
325        }
326
327        if options.compute_error_bounds {
328            result.error_bounds = state.error_bounds();
329        }
330
331        Ok(result)
332    }
333}
334
335/// Trait for solver state management.
336pub trait SolverState: Send + Sync {
337    /// Get the current residual norm.
338    fn residual_norm(&self) -> Precision;
339
340    /// Get the number of matrix-vector multiplications performed.
341    fn matvec_count(&self) -> usize;
342
343    /// Get error bounds if available.
344    fn error_bounds(&self) -> Option<ErrorBounds>;
345
346    /// Get current memory usage.
347    fn memory_usage(&self) -> MemoryInfo;
348
349    /// Reset the state for a new solve.
350    fn reset(&mut self);
351}
352
/// Outcome reported by [`SolverAlgorithm::step`] after one iteration.
#[derive(Debug, Clone, PartialEq)]
pub enum StepResult {
    /// The step succeeded; the driver should keep iterating.
    Continue,
    /// The step reached convergence; the driver should stop.
    Converged,
    /// The step failed; carries a human-readable explanation.
    Failed(String),
}
363
364/// Utility functions for solver implementations.
365pub mod utils {
366    use super::*;
367
368    /// Compute the L2 norm of a vector.
369    pub fn l2_norm(v: &[Precision]) -> Precision {
370        v.iter().map(|x| x * x).sum::<Precision>().sqrt()
371    }
372
373    /// Compute the L1 norm of a vector.
374    pub fn l1_norm(v: &[Precision]) -> Precision {
375        v.iter().map(|x| x.abs()).sum()
376    }
377
378    /// Compute the L∞ norm of a vector.
379    pub fn linf_norm(v: &[Precision]) -> Precision {
380        v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
381    }
382
383    /// Compute vector norm according to specified type.
384    pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
385        match norm_type {
386            NormType::L1 => l1_norm(v),
387            NormType::L2 => l2_norm(v),
388            NormType::LInfinity => linf_norm(v),
389            NormType::Weighted => l2_norm(v), // Default to L2 for weighted
390        }
391    }
392
393    /// Compute residual vector: r = A*x - b
394    pub fn compute_residual(
395        matrix: &dyn Matrix,
396        x: &[Precision],
397        b: &[Precision],
398        residual: &mut [Precision],
399    ) -> Result<()> {
400        matrix.multiply_vector(x, residual)?;
401        for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
402            *r -= b_val;
403        }
404        Ok(())
405    }
406
407    /// Check convergence based on specified criteria.
408    pub fn check_convergence(
409        residual_norm: Precision,
410        tolerance: Precision,
411        mode: ConvergenceMode,
412        b_norm: Precision,
413        prev_solution: Option<&[Precision]>,
414        current_solution: &[Precision],
415    ) -> bool {
416        match mode {
417            ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
418            ConvergenceMode::RelativeResidual => {
419                if b_norm > 0.0 {
420                    (residual_norm / b_norm) <= tolerance
421                } else {
422                    residual_norm <= tolerance
423                }
424            },
425            ConvergenceMode::SolutionChange => {
426                if let Some(prev) = prev_solution {
427                    let mut change_norm = 0.0;
428                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
429                        let diff = curr - prev_val;
430                        change_norm += diff * diff;
431                    }
432                    change_norm.sqrt() <= tolerance
433                } else {
434                    false
435                }
436            },
437            ConvergenceMode::RelativeSolutionChange => {
438                if let Some(prev) = prev_solution {
439                    let mut change_norm = 0.0;
440                    let mut solution_norm = 0.0;
441                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
442                        let diff = curr - prev_val;
443                        change_norm += diff * diff;
444                        solution_norm += prev_val * prev_val;
445                    }
446                    if solution_norm > 0.0 {
447                        (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
448                    } else {
449                        change_norm.sqrt() <= tolerance
450                    }
451                } else {
452                    false
453                }
454            },
455            ConvergenceMode::Combined => {
456                // Use the most conservative criterion
457                residual_norm <= tolerance &&
458                (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
459            },
460        }
461    }
462}
463
// Solver types whose algorithms have not been written yet; their trait impls are placeholders.

/// Forward-push solver (placeholder: algorithm not implemented yet).
pub struct ForwardPushSolver;
/// Backward-push solver (placeholder: algorithm not implemented yet).
pub struct BackwardPushSolver;
/// Hybrid random-walk solver (placeholder: algorithm not implemented yet).
pub struct HybridSolver;
468
469// Placeholder implementations - will be implemented in separate modules
470impl SolverAlgorithm for ForwardPushSolver {
471    type State = ();
472
473    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> {
474        Err(SolverError::AlgorithmError {
475            algorithm: "forward_push".to_string(),
476            message: "Not implemented yet".to_string(),
477            context: vec![],
478        })
479    }
480
481    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
482        Err(SolverError::AlgorithmError {
483            algorithm: "forward_push".to_string(),
484            message: "Not implemented yet".to_string(),
485            context: vec![],
486        })
487    }
488
489    fn is_converged(&self, _state: &Self::State) -> bool {
490        false
491    }
492
493    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
494        Vec::new()
495    }
496
497    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
498        Err(SolverError::AlgorithmError {
499            algorithm: "forward_push".to_string(),
500            message: "Not implemented yet".to_string(),
501            context: vec![],
502        })
503    }
504
505    fn algorithm_name(&self) -> &'static str {
506        "forward_push"
507    }
508}
509
510impl SolverState for () {
511    fn residual_norm(&self) -> Precision {
512        0.0
513    }
514
515    fn matvec_count(&self) -> usize {
516        0
517    }
518
519    fn error_bounds(&self) -> Option<ErrorBounds> {
520        None
521    }
522
523    fn memory_usage(&self) -> MemoryInfo {
524        MemoryInfo {
525            current_usage_bytes: 0,
526            peak_usage_bytes: 0,
527            matrix_memory_bytes: 0,
528            vector_memory_bytes: 0,
529            workspace_memory_bytes: 0,
530            allocation_count: 0,
531            deallocation_count: 0,
532        }
533    }
534
535    fn reset(&mut self) {}
536}
537
538// Similar placeholder implementations for BackwardPushSolver and HybridSolver
539impl SolverAlgorithm for BackwardPushSolver {
540    type State = ();
541    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
542    fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
543    fn is_converged(&self, _state: &Self::State) -> bool { true }
544    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
545    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
546    fn algorithm_name(&self) -> &'static str { "backward_push" }
547}
548
549impl SolverAlgorithm for HybridSolver {
550    type State = ();
551    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
552    fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
553    fn is_converged(&self, _state: &Self::State) -> bool { true }
554    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
555    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
556    fn algorithm_name(&self) -> &'static str { "hybrid" }
557}
558
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);

        let streaming_opts = SolverOptions::streaming(25);
        assert_eq!(streaming_opts.streaming_interval, 25);
        assert!(streaming_opts.collect_stats);
        assert!(streaming_opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));

        // An unconverged run never meets the criteria, however small its residual.
        let failed = SolverResult::failure(vec![1.0, 2.0], 1e-8, 10);
        assert!(!failed.converged);
        assert!(!failed.meets_quality_criteria(1e-6));

        let errored = SolverResult::error(SolverError::AlgorithmError {
            algorithm: "test".to_string(),
            message: "boom".to_string(),
            context: vec![],
        });
        assert!(errored.solution.is_empty());
        assert_eq!(errored.iterations, 0);
        assert!(errored.residual_norm.is_infinite());
        assert!(!errored.converged);
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = vec![3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);

        // Signs must not matter.
        let w = vec![-3.0, -4.0];
        assert_eq!(l1_norm(&w), 7.0);
        assert_eq!(l2_norm(&w), 5.0);
        assert_eq!(linf_norm(&w), 4.0);

        assert_eq!(compute_norm(&v, NormType::L1), 7.0);
        assert_eq!(compute_norm(&v, NormType::L2), 5.0);
        assert_eq!(compute_norm(&v, NormType::LInfinity), 4.0);
        assert_eq!(compute_norm(&v, NormType::Weighted), l2_norm(&v));

        assert_eq!(l1_norm(&[]), 0.0);
        assert_eq!(l2_norm(&[]), 0.0);
        assert_eq!(linf_norm(&[]), 0.0);
    }

    #[test]
    fn test_convergence_modes() {
        use utils::check_convergence;

        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[]));
        assert!(!check_convergence(1e-5, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[]));
        assert!(check_convergence(1e-3, 1e-6, ConvergenceMode::RelativeResidual, 1e4, None, &[]));
        // Combined also requires the absolute test, which 1e-3 fails.
        assert!(!check_convergence(1e-3, 1e-6, ConvergenceMode::Combined, 1e4, None, &[]));

        let prev = [3.0, 4.0];
        let curr = [3.0, 4.5];
        assert!(!check_convergence(0.0, 1.0, ConvergenceMode::SolutionChange, 1.0, None, &curr));
        assert!(check_convergence(0.0, 0.6, ConvergenceMode::SolutionChange, 1.0, Some(&prev[..]), &curr));
        assert!(check_convergence(0.0, 0.2, ConvergenceMode::RelativeSolutionChange, 1.0, Some(&prev[..]), &curr));
        assert!(!check_convergence(0.0, 0.05, ConvergenceMode::RelativeSolutionChange, 1.0, Some(&prev[..]), &curr));
    }
}