1use crate::error::{Result, SolverError};
7use crate::matrix::Matrix;
8use crate::types::{
9 ConvergenceMode, ErrorBounds, MemoryInfo, NormType, Precision, ProfileData, SolverStats,
10};
11use alloc::{string::String, vec::Vec};
12
13pub mod neumann;
14
15pub use neumann::NeumannSolver;
17
/// Configuration shared by the solvers in this module.
///
/// Start from `SolverOptions::default()` or one of the presets
/// (`high_precision`, `fast`, `streaming`) and override individual fields.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverOptions {
    /// Convergence threshold. How it is compared depends on `convergence_mode`
    /// (see `utils::check_convergence`); it is also reported in
    /// `SolverError::ConvergenceFailure`.
    pub tolerance: Precision,
    /// Cap on the number of counted iterations in `SolverAlgorithm::solve`.
    pub max_iterations: usize,
    /// Which quantity is tested against `tolerance`.
    pub convergence_mode: ConvergenceMode,
    /// Vector norm selection.
    /// NOTE(review): not read by the default `solve`; presumably consumed by
    /// individual algorithms via `utils::compute_norm` — confirm.
    pub norm_type: NormType,
    /// When true (and the `std` feature is enabled), `solve` attaches timing
    /// and mat-vec statistics to the result.
    pub collect_stats: bool,
    /// Streaming interval; `0` in every preset except `streaming(interval)`.
    /// NOTE(review): presumably `0` disables streaming — confirm with the consumer.
    pub streaming_interval: usize,
    /// Optional starting vector.
    /// NOTE(review): whether each algorithm honours it is not visible here.
    pub initial_guess: Option<Vec<Precision>>,
    /// When true, `solve` copies `SolverState::error_bounds()` into the result.
    pub compute_error_bounds: bool,
    /// Tolerance associated with error-bound estimation.
    /// NOTE(review): not read in this file.
    pub error_bounds_tolerance: Precision,
    /// Profiling switch. NOTE(review): not read in this file.
    pub enable_profiling: bool,
    /// Optional RNG seed. NOTE(review): not read in this file; presumably for
    /// randomized algorithms.
    pub random_seed: Option<u64>,
    /// Coherence threshold, `0.0` in every preset.
    /// NOTE(review): semantics not visible in this file.
    pub coherence_threshold: Precision,
}
56
57impl Default for SolverOptions {
58 fn default() -> Self {
59 Self {
60 tolerance: 1e-6,
61 max_iterations: 1000,
62 convergence_mode: ConvergenceMode::ResidualNorm,
63 norm_type: NormType::L2,
64 collect_stats: false,
65 streaming_interval: 0,
66 initial_guess: None,
67 compute_error_bounds: false,
68 error_bounds_tolerance: 1e-8,
69 enable_profiling: false,
70 random_seed: None,
71 coherence_threshold: 0.0,
73 }
74 }
75}
76
77impl SolverOptions {
78 pub fn high_precision() -> Self {
80 Self {
81 tolerance: 1e-12,
82 max_iterations: 5000,
83 convergence_mode: ConvergenceMode::Combined,
84 norm_type: NormType::L2,
85 collect_stats: true,
86 streaming_interval: 0,
87 initial_guess: None,
88 compute_error_bounds: true,
89 error_bounds_tolerance: 1e-14,
90 enable_profiling: false,
91 random_seed: None,
92 coherence_threshold: 0.0,
93 }
94 }
95
96 pub fn fast() -> Self {
98 Self {
99 tolerance: 1e-3,
100 max_iterations: 100,
101 convergence_mode: ConvergenceMode::ResidualNorm,
102 norm_type: NormType::L2,
103 collect_stats: false,
104 streaming_interval: 0,
105 initial_guess: None,
106 compute_error_bounds: false,
107 error_bounds_tolerance: 1e-4,
108 enable_profiling: false,
109 random_seed: None,
110 coherence_threshold: 0.0,
111 }
112 }
113
114 pub fn streaming(interval: usize) -> Self {
116 Self {
117 tolerance: 1e-4,
118 max_iterations: 1000,
119 convergence_mode: ConvergenceMode::ResidualNorm,
120 norm_type: NormType::L2,
121 collect_stats: true,
122 streaming_interval: interval,
123 initial_guess: None,
124 compute_error_bounds: false,
125 error_bounds_tolerance: 1e-6,
126 enable_profiling: true,
127 random_seed: None,
128 coherence_threshold: 0.0,
129 }
130 }
131}
132
/// Outcome of a solve.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverResult {
    /// Solution vector; empty for results built by `SolverResult::error`.
    pub solution: Vec<Precision>,
    /// Residual norm reported by the solver state when the run ended.
    pub residual_norm: Precision,
    /// Iteration count reported by the solver.
    pub iterations: usize,
    /// Whether the run met its convergence criterion.
    pub converged: bool,
    /// Filled in by `SolverAlgorithm::solve` when `compute_error_bounds` is set.
    pub error_bounds: Option<ErrorBounds>,
    /// Filled in by `SolverAlgorithm::solve` when `collect_stats` is set and
    /// the `std` feature is enabled.
    pub stats: Option<SolverStats>,
    /// Memory usage information. NOTE(review): never populated in this file.
    pub memory_info: Option<MemoryInfo>,
    /// Profiling samples. NOTE(review): never populated in this file.
    pub profile_data: Option<Vec<ProfileData>>,
}
154
155impl SolverResult {
156 pub fn success(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
158 Self {
159 solution,
160 residual_norm,
161 iterations,
162 converged: true,
163 error_bounds: None,
164 stats: None,
165 memory_info: None,
166 profile_data: None,
167 }
168 }
169
170 pub fn failure(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
172 Self {
173 solution,
174 residual_norm,
175 iterations,
176 converged: false,
177 error_bounds: None,
178 stats: None,
179 memory_info: None,
180 profile_data: None,
181 }
182 }
183
184 pub fn error(error: SolverError) -> Self {
186 Self {
187 solution: Vec::new(),
188 residual_norm: Precision::INFINITY,
189 iterations: 0,
190 converged: false,
191 error_bounds: None,
192 stats: None,
193 memory_info: None,
194 profile_data: None,
195 }
196 }
197
198 pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
200 self.converged && self.residual_norm <= tolerance
201 }
202}
203
/// Snapshot of an in-progress solve.
///
/// NOTE(review): not constructed in this file; presumably emitted by streaming
/// solvers — confirm.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PartialSolution {
    /// Iteration at which the snapshot was taken.
    pub iteration: usize,
    /// Solution estimate at that iteration.
    pub solution: Vec<Precision>,
    /// Residual norm at that iteration.
    pub residual_norm: Precision,
    /// Whether convergence had been reached.
    pub converged: bool,
    /// Estimated number of remaining iterations, when the solver can estimate it.
    pub estimated_remaining: Option<usize>,
    /// Capture time. With serde enabled, this field is skipped when
    /// serializing and set to `std::time::Instant::now()` when deserializing.
    #[cfg(feature = "std")]
    #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
    pub timestamp: std::time::Instant,
    /// Capture time in `no_std` builds.
    /// NOTE(review): unit and epoch are not visible here — confirm with producers.
    #[cfg(not(feature = "std"))]
    pub timestamp: u64,
}
225
226pub trait SolverAlgorithm: Send + Sync {
231 type State: SolverState;
233
234 fn initialize(
236 &self,
237 matrix: &dyn Matrix,
238 b: &[Precision],
239 options: &SolverOptions,
240 ) -> Result<Self::State>;
241
242 fn step(&self, state: &mut Self::State) -> Result<StepResult>;
244
245 fn is_converged(&self, state: &Self::State) -> bool;
247
248 fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
250
251 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
253
254 fn algorithm_name(&self) -> &'static str;
256
257 fn solve(
262 &self,
263 matrix: &dyn Matrix,
264 b: &[Precision],
265 options: &SolverOptions,
266 ) -> Result<SolverResult> {
267 let mut state = self.initialize(matrix, b, options)?;
268 let mut iterations = 0;
269
270 #[cfg(feature = "std")]
271 let start_time = std::time::Instant::now();
272
273 while !self.is_converged(&state) && iterations < options.max_iterations {
274 match self.step(&mut state)? {
275 StepResult::Continue => {
276 iterations += 1;
277
278 let residual = state.residual_norm();
280 if !residual.is_finite() {
281 return Err(SolverError::NumericalInstability {
282 reason: "Non-finite residual norm".to_string(),
283 iteration: iterations,
284 residual_norm: residual,
285 });
286 }
287 }
288 StepResult::Converged => break,
289 StepResult::Failed(reason) => {
290 return Err(SolverError::AlgorithmError {
291 algorithm: self.algorithm_name().to_string(),
292 message: reason,
293 context: vec![
294 ("iteration".to_string(), iterations.to_string()),
295 (
296 "residual_norm".to_string(),
297 state.residual_norm().to_string(),
298 ),
299 ],
300 });
301 }
302 }
303 }
304
305 let converged = self.is_converged(&state);
306 let solution = self.extract_solution(&state);
307 let residual_norm = state.residual_norm();
308
309 if !converged && iterations >= options.max_iterations {
311 return Err(SolverError::ConvergenceFailure {
312 iterations,
313 residual_norm,
314 tolerance: options.tolerance,
315 algorithm: self.algorithm_name().to_string(),
316 });
317 }
318
319 let mut result = if converged {
320 SolverResult::success(solution, residual_norm, iterations)
321 } else {
322 SolverResult::failure(solution, residual_norm, iterations)
323 };
324
325 if options.collect_stats {
327 #[cfg(feature = "std")]
328 {
329 let total_time = start_time.elapsed().as_millis() as f64;
330 let mut stats = SolverStats::new();
331 stats.total_time_ms = total_time;
332 stats.matvec_count = state.matvec_count();
333 result.stats = Some(stats);
334 }
335 }
336
337 if options.compute_error_bounds {
338 result.error_bounds = state.error_bounds();
339 }
340
341 Ok(result)
342 }
343}
344
/// Per-solve state produced by `SolverAlgorithm::initialize` and advanced by
/// `SolverAlgorithm::step`.
pub trait SolverState: Send + Sync {
    /// Current residual norm. The default `solve` aborts with
    /// `NumericalInstability` if this is non-finite after a step.
    fn residual_norm(&self) -> Precision;

    /// Matrix-vector products performed so far; copied into
    /// `SolverStats::matvec_count` when statistics are collected.
    fn matvec_count(&self) -> usize;

    /// Error bounds for the current iterate, or `None` if unavailable.
    fn error_bounds(&self) -> Option<ErrorBounds>;

    /// Memory usage of this state.
    /// NOTE(review): not called by the default `solve`.
    fn memory_usage(&self) -> MemoryInfo;

    /// Resets the state.
    /// NOTE(review): the post-reset contents are implementation-defined.
    fn reset(&mut self);
}
362
/// Outcome of a single `SolverAlgorithm::step` call.
///
/// `Eq` is derived alongside `PartialEq` because the only payload (`String`)
/// has total equality, so the type can be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepResult {
    /// The step completed; the driver should keep iterating.
    Continue,
    /// The algorithm reports convergence; the driver stops iterating.
    Converged,
    /// Unrecoverable failure with a human-readable reason. The default `solve`
    /// turns this into `SolverError::AlgorithmError`.
    Failed(String),
}
373
374pub mod utils {
376 use super::*;
377
378 pub fn l2_norm(v: &[Precision]) -> Precision {
380 v.iter().map(|x| x * x).sum::<Precision>().sqrt()
381 }
382
383 pub fn l1_norm(v: &[Precision]) -> Precision {
385 v.iter().map(|x| x.abs()).sum()
386 }
387
388 pub fn linf_norm(v: &[Precision]) -> Precision {
390 v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
391 }
392
393 pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
395 match norm_type {
396 NormType::L1 => l1_norm(v),
397 NormType::L2 => l2_norm(v),
398 NormType::LInfinity => linf_norm(v),
399 NormType::Weighted => l2_norm(v), }
401 }
402
403 pub fn compute_residual(
405 matrix: &dyn Matrix,
406 x: &[Precision],
407 b: &[Precision],
408 residual: &mut [Precision],
409 ) -> Result<()> {
410 matrix.multiply_vector(x, residual)?;
411 for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
412 *r -= b_val;
413 }
414 Ok(())
415 }
416
417 pub fn check_convergence(
419 residual_norm: Precision,
420 tolerance: Precision,
421 mode: ConvergenceMode,
422 b_norm: Precision,
423 prev_solution: Option<&[Precision]>,
424 current_solution: &[Precision],
425 ) -> bool {
426 match mode {
427 ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
428 ConvergenceMode::RelativeResidual => {
429 if b_norm > 0.0 {
430 (residual_norm / b_norm) <= tolerance
431 } else {
432 residual_norm <= tolerance
433 }
434 }
435 ConvergenceMode::SolutionChange => {
436 if let Some(prev) = prev_solution {
437 let mut change_norm = 0.0;
438 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
439 let diff = curr - prev_val;
440 change_norm += diff * diff;
441 }
442 change_norm.sqrt() <= tolerance
443 } else {
444 false
445 }
446 }
447 ConvergenceMode::RelativeSolutionChange => {
448 if let Some(prev) = prev_solution {
449 let mut change_norm = 0.0;
450 let mut solution_norm = 0.0;
451 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
452 let diff = curr - prev_val;
453 change_norm += diff * diff;
454 solution_norm += prev_val * prev_val;
455 }
456 if solution_norm > 0.0 {
457 (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
458 } else {
459 change_norm.sqrt() <= tolerance
460 }
461 } else {
462 false
463 }
464 }
465 ConvergenceMode::Combined => {
466 residual_norm <= tolerance
468 && (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
469 }
470 }
471 }
472}
473
/// Forward-push solver.
///
/// NOTE: its `SolverAlgorithm` implementation is currently a stub.
#[derive(Debug, Clone, Copy, Default)]
pub struct ForwardPushSolver;

/// Backward-push solver.
///
/// NOTE: its `SolverAlgorithm` implementation is currently a placeholder.
#[derive(Debug, Clone, Copy, Default)]
pub struct BackwardPushSolver;

/// Hybrid solver.
///
/// NOTE: its `SolverAlgorithm` implementation is currently a placeholder.
#[derive(Debug, Clone, Copy, Default)]
pub struct HybridSolver;
478
479impl SolverAlgorithm for ForwardPushSolver {
481 type State = ();
482
483 fn initialize(
484 &self,
485 _matrix: &dyn Matrix,
486 _b: &[Precision],
487 _options: &SolverOptions,
488 ) -> Result<Self::State> {
489 Err(SolverError::AlgorithmError {
490 algorithm: "forward_push".to_string(),
491 message: "Not implemented yet".to_string(),
492 context: vec![],
493 })
494 }
495
496 fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
497 Err(SolverError::AlgorithmError {
498 algorithm: "forward_push".to_string(),
499 message: "Not implemented yet".to_string(),
500 context: vec![],
501 })
502 }
503
504 fn is_converged(&self, _state: &Self::State) -> bool {
505 false
506 }
507
508 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
509 Vec::new()
510 }
511
512 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
513 Err(SolverError::AlgorithmError {
514 algorithm: "forward_push".to_string(),
515 message: "Not implemented yet".to_string(),
516 context: vec![],
517 })
518 }
519
520 fn algorithm_name(&self) -> &'static str {
521 "forward_push"
522 }
523}
524
525impl SolverState for () {
526 fn residual_norm(&self) -> Precision {
527 0.0
528 }
529
530 fn matvec_count(&self) -> usize {
531 0
532 }
533
534 fn error_bounds(&self) -> Option<ErrorBounds> {
535 None
536 }
537
538 fn memory_usage(&self) -> MemoryInfo {
539 MemoryInfo {
540 current_usage_bytes: 0,
541 peak_usage_bytes: 0,
542 matrix_memory_bytes: 0,
543 vector_memory_bytes: 0,
544 workspace_memory_bytes: 0,
545 allocation_count: 0,
546 deallocation_count: 0,
547 }
548 }
549
550 fn reset(&mut self) {}
551}
552
553impl SolverAlgorithm for BackwardPushSolver {
555 type State = ();
556 fn initialize(
557 &self,
558 _matrix: &dyn Matrix,
559 _b: &[Precision],
560 _options: &SolverOptions,
561 ) -> Result<Self::State> {
562 Ok(())
563 }
564 fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
565 Ok(StepResult::Converged)
566 }
567 fn is_converged(&self, _state: &Self::State) -> bool {
568 true
569 }
570 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
571 Vec::new()
572 }
573 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
574 Ok(())
575 }
576 fn algorithm_name(&self) -> &'static str {
577 "backward_push"
578 }
579}
580
581impl SolverAlgorithm for HybridSolver {
582 type State = ();
583 fn initialize(
584 &self,
585 _matrix: &dyn Matrix,
586 _b: &[Precision],
587 _options: &SolverOptions,
588 ) -> Result<Self::State> {
589 Ok(())
590 }
591 fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
592 Ok(StepResult::Converged)
593 }
594 fn is_converged(&self, _state: &Self::State) -> bool {
595 true
596 }
597 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
598 Vec::new()
599 }
600 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
601 Ok(())
602 }
603 fn algorithm_name(&self) -> &'static str {
604 "hybrid"
605 }
606}
607
/// Unit tests for options, results and the `utils` helpers.
///
/// The unused `SparseMatrix` import was dropped; it only produced a warning.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);

        let streaming_opts = SolverOptions::streaming(5);
        assert_eq!(streaming_opts.streaming_interval, 5);
        assert_eq!(streaming_opts.tolerance, 1e-4);
        assert!(streaming_opts.collect_stats);
        assert!(streaming_opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));
    }

    #[test]
    fn test_failure_result_never_meets_quality_criteria() {
        // Even a zero residual does not qualify when the run did not converge.
        let result = SolverResult::failure(vec![0.5], 0.0, 3);
        assert!(!result.converged);
        assert_eq!(result.iterations, 3);
        assert!(!result.meets_quality_criteria(1.0));
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = vec![3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);
    }

    #[test]
    fn test_compute_norm_dispatch() {
        use utils::compute_norm;

        let v: Vec<Precision> = vec![3.0, -4.0];
        assert_eq!(compute_norm(&v, NormType::L1), 7.0);
        assert_eq!(compute_norm(&v, NormType::L2), 5.0);
        assert_eq!(compute_norm(&v, NormType::LInfinity), 4.0);
        // Weighted falls back to L2.
        assert_eq!(compute_norm(&v, NormType::Weighted), 5.0);
    }

    #[test]
    fn test_check_convergence_modes() {
        use utils::check_convergence;

        let current: [Precision; 2] = [1.0, 2.0];
        let same: [Precision; 2] = [1.0, 2.0];
        let far: [Precision; 2] = [1.0, 1.0];

        // Absolute residual.
        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &current));
        assert!(!check_convergence(1e-5, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &current));

        // Relative residual scales by ||b||; zero ||b|| falls back to absolute.
        assert!(check_convergence(1e-3, 1e-6, ConvergenceMode::RelativeResidual, 1e4, None, &current));
        assert!(!check_convergence(1e-3, 1e-6, ConvergenceMode::RelativeResidual, 0.0, None, &current));

        // Solution change needs a previous iterate.
        assert!(!check_convergence(0.0, 1e-6, ConvergenceMode::SolutionChange, 1.0, None, &current));
        assert!(check_convergence(
            1.0,
            1e-6,
            ConvergenceMode::SolutionChange,
            1.0,
            Some(&same[..]),
            &current
        ));
        assert!(!check_convergence(
            0.0,
            1e-6,
            ConvergenceMode::SolutionChange,
            1.0,
            Some(&far[..]),
            &current
        ));

        // Relative change: |Δ| = 1, |prev| = sqrt(2), ratio ≈ 0.707.
        assert!(check_convergence(
            0.0,
            1.0,
            ConvergenceMode::RelativeSolutionChange,
            1.0,
            Some(&far[..]),
            &current
        ));
        assert!(!check_convergence(
            0.0,
            0.5,
            ConvergenceMode::RelativeSolutionChange,
            1.0,
            Some(&far[..]),
            &current
        ));

        // Combined requires both the absolute and the relative test.
        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::Combined, 1.0, None, &current));
        assert!(!check_convergence(1e-7, 1e-6, ConvergenceMode::Combined, 1e-3, None, &current));
    }
}