1use crate::matrix::Matrix;
7use crate::types::{
8 Precision, ConvergenceMode, NormType, ErrorBounds, SolverStats,
9 DimensionType, MemoryInfo, ProfileData
10};
11use crate::error::{SolverError, Result};
12use alloc::{vec::Vec, string::String, boxed::Box};
13
14pub mod neumann;
15
16pub use neumann::NeumannSolver;
18
19#[derive(Debug, Clone, PartialEq)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SolverOptions {
23 pub tolerance: Precision,
25 pub max_iterations: usize,
27 pub convergence_mode: ConvergenceMode,
29 pub norm_type: NormType,
31 pub collect_stats: bool,
33 pub streaming_interval: usize,
35 pub initial_guess: Option<Vec<Precision>>,
37 pub compute_error_bounds: bool,
39 pub error_bounds_tolerance: Precision,
41 pub enable_profiling: bool,
43 pub random_seed: Option<u64>,
45}
46
47impl Default for SolverOptions {
48 fn default() -> Self {
49 Self {
50 tolerance: 1e-6,
51 max_iterations: 1000,
52 convergence_mode: ConvergenceMode::ResidualNorm,
53 norm_type: NormType::L2,
54 collect_stats: false,
55 streaming_interval: 0,
56 initial_guess: None,
57 compute_error_bounds: false,
58 error_bounds_tolerance: 1e-8,
59 enable_profiling: false,
60 random_seed: None,
61 }
62 }
63}
64
65impl SolverOptions {
66 pub fn high_precision() -> Self {
68 Self {
69 tolerance: 1e-12,
70 max_iterations: 5000,
71 convergence_mode: ConvergenceMode::Combined,
72 norm_type: NormType::L2,
73 collect_stats: true,
74 streaming_interval: 0,
75 initial_guess: None,
76 compute_error_bounds: true,
77 error_bounds_tolerance: 1e-14,
78 enable_profiling: false,
79 random_seed: None,
80 }
81 }
82
83 pub fn fast() -> Self {
85 Self {
86 tolerance: 1e-3,
87 max_iterations: 100,
88 convergence_mode: ConvergenceMode::ResidualNorm,
89 norm_type: NormType::L2,
90 collect_stats: false,
91 streaming_interval: 0,
92 initial_guess: None,
93 compute_error_bounds: false,
94 error_bounds_tolerance: 1e-4,
95 enable_profiling: false,
96 random_seed: None,
97 }
98 }
99
100 pub fn streaming(interval: usize) -> Self {
102 Self {
103 tolerance: 1e-4,
104 max_iterations: 1000,
105 convergence_mode: ConvergenceMode::ResidualNorm,
106 norm_type: NormType::L2,
107 collect_stats: true,
108 streaming_interval: interval,
109 initial_guess: None,
110 compute_error_bounds: false,
111 error_bounds_tolerance: 1e-6,
112 enable_profiling: true,
113 random_seed: None,
114 }
115 }
116}
117
118#[derive(Debug, Clone, PartialEq)]
120#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
121pub struct SolverResult {
122 pub solution: Vec<Precision>,
124 pub residual_norm: Precision,
126 pub iterations: usize,
128 pub converged: bool,
130 pub error_bounds: Option<ErrorBounds>,
132 pub stats: Option<SolverStats>,
134 pub memory_info: Option<MemoryInfo>,
136 pub profile_data: Option<Vec<ProfileData>>,
138}
139
140impl SolverResult {
141 pub fn success(
143 solution: Vec<Precision>,
144 residual_norm: Precision,
145 iterations: usize,
146 ) -> Self {
147 Self {
148 solution,
149 residual_norm,
150 iterations,
151 converged: true,
152 error_bounds: None,
153 stats: None,
154 memory_info: None,
155 profile_data: None,
156 }
157 }
158
159 pub fn failure(
161 solution: Vec<Precision>,
162 residual_norm: Precision,
163 iterations: usize,
164 ) -> Self {
165 Self {
166 solution,
167 residual_norm,
168 iterations,
169 converged: false,
170 error_bounds: None,
171 stats: None,
172 memory_info: None,
173 profile_data: None,
174 }
175 }
176
177 pub fn error(error: SolverError) -> Self {
179 Self {
180 solution: Vec::new(),
181 residual_norm: Precision::INFINITY,
182 iterations: 0,
183 converged: false,
184 error_bounds: None,
185 stats: None,
186 memory_info: None,
187 profile_data: None,
188 }
189 }
190
191 pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
193 self.converged && self.residual_norm <= tolerance
194 }
195}
196
197#[derive(Debug, Clone, PartialEq)]
199#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
200pub struct PartialSolution {
201 pub iteration: usize,
203 pub solution: Vec<Precision>,
205 pub residual_norm: Precision,
207 pub converged: bool,
209 pub estimated_remaining: Option<usize>,
211 #[cfg(feature = "std")]
213 #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
214 pub timestamp: std::time::Instant,
215 #[cfg(not(feature = "std"))]
216 pub timestamp: u64,
217}
218
219pub trait SolverAlgorithm: Send + Sync {
224 type State: SolverState;
226
227 fn initialize(
229 &self,
230 matrix: &dyn Matrix,
231 b: &[Precision],
232 options: &SolverOptions,
233 ) -> Result<Self::State>;
234
235 fn step(&self, state: &mut Self::State) -> Result<StepResult>;
237
238 fn is_converged(&self, state: &Self::State) -> bool;
240
241 fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
243
244 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
246
247 fn algorithm_name(&self) -> &'static str;
249
250 fn solve(
255 &self,
256 matrix: &dyn Matrix,
257 b: &[Precision],
258 options: &SolverOptions,
259 ) -> Result<SolverResult> {
260 let mut state = self.initialize(matrix, b, options)?;
261 let mut iterations = 0;
262
263 #[cfg(feature = "std")]
264 let start_time = std::time::Instant::now();
265
266 while !self.is_converged(&state) && iterations < options.max_iterations {
267 match self.step(&mut state)? {
268 StepResult::Continue => {
269 iterations += 1;
270
271 let residual = state.residual_norm();
273 if !residual.is_finite() {
274 return Err(SolverError::NumericalInstability {
275 reason: "Non-finite residual norm".to_string(),
276 iteration: iterations,
277 residual_norm: residual,
278 });
279 }
280 },
281 StepResult::Converged => break,
282 StepResult::Failed(reason) => {
283 return Err(SolverError::AlgorithmError {
284 algorithm: self.algorithm_name().to_string(),
285 message: reason,
286 context: vec![
287 ("iteration".to_string(), iterations.to_string()),
288 ("residual_norm".to_string(), state.residual_norm().to_string()),
289 ],
290 });
291 }
292 }
293 }
294
295 let converged = self.is_converged(&state);
296 let solution = self.extract_solution(&state);
297 let residual_norm = state.residual_norm();
298
299 if !converged && iterations >= options.max_iterations {
301 return Err(SolverError::ConvergenceFailure {
302 iterations,
303 residual_norm,
304 tolerance: options.tolerance,
305 algorithm: self.algorithm_name().to_string(),
306 });
307 }
308
309 let mut result = if converged {
310 SolverResult::success(solution, residual_norm, iterations)
311 } else {
312 SolverResult::failure(solution, residual_norm, iterations)
313 };
314
315 if options.collect_stats {
317 #[cfg(feature = "std")]
318 {
319 let total_time = start_time.elapsed().as_millis() as f64;
320 let mut stats = SolverStats::new();
321 stats.total_time_ms = total_time;
322 stats.matvec_count = state.matvec_count();
323 result.stats = Some(stats);
324 }
325 }
326
327 if options.compute_error_bounds {
328 result.error_bounds = state.error_bounds();
329 }
330
331 Ok(result)
332 }
333}
334
335pub trait SolverState: Send + Sync {
337 fn residual_norm(&self) -> Precision;
339
340 fn matvec_count(&self) -> usize;
342
343 fn error_bounds(&self) -> Option<ErrorBounds>;
345
346 fn memory_usage(&self) -> MemoryInfo;
348
349 fn reset(&mut self);
351}
352
/// Outcome of a single [`SolverAlgorithm::step`] call.
#[derive(Clone, Debug, PartialEq)]
pub enum StepResult {
    /// The step succeeded; `solve` counts it and keeps iterating.
    Continue,
    /// The algorithm reports convergence; `solve` stops iterating.
    Converged,
    /// The algorithm cannot proceed. `solve` reports the message through
    /// `SolverError::AlgorithmError`.
    Failed(String),
}
363
364pub mod utils {
366 use super::*;
367
368 pub fn l2_norm(v: &[Precision]) -> Precision {
370 v.iter().map(|x| x * x).sum::<Precision>().sqrt()
371 }
372
373 pub fn l1_norm(v: &[Precision]) -> Precision {
375 v.iter().map(|x| x.abs()).sum()
376 }
377
378 pub fn linf_norm(v: &[Precision]) -> Precision {
380 v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
381 }
382
383 pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
385 match norm_type {
386 NormType::L1 => l1_norm(v),
387 NormType::L2 => l2_norm(v),
388 NormType::LInfinity => linf_norm(v),
389 NormType::Weighted => l2_norm(v), }
391 }
392
393 pub fn compute_residual(
395 matrix: &dyn Matrix,
396 x: &[Precision],
397 b: &[Precision],
398 residual: &mut [Precision],
399 ) -> Result<()> {
400 matrix.multiply_vector(x, residual)?;
401 for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
402 *r -= b_val;
403 }
404 Ok(())
405 }
406
407 pub fn check_convergence(
409 residual_norm: Precision,
410 tolerance: Precision,
411 mode: ConvergenceMode,
412 b_norm: Precision,
413 prev_solution: Option<&[Precision]>,
414 current_solution: &[Precision],
415 ) -> bool {
416 match mode {
417 ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
418 ConvergenceMode::RelativeResidual => {
419 if b_norm > 0.0 {
420 (residual_norm / b_norm) <= tolerance
421 } else {
422 residual_norm <= tolerance
423 }
424 },
425 ConvergenceMode::SolutionChange => {
426 if let Some(prev) = prev_solution {
427 let mut change_norm = 0.0;
428 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
429 let diff = curr - prev_val;
430 change_norm += diff * diff;
431 }
432 change_norm.sqrt() <= tolerance
433 } else {
434 false
435 }
436 },
437 ConvergenceMode::RelativeSolutionChange => {
438 if let Some(prev) = prev_solution {
439 let mut change_norm = 0.0;
440 let mut solution_norm = 0.0;
441 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
442 let diff = curr - prev_val;
443 change_norm += diff * diff;
444 solution_norm += prev_val * prev_val;
445 }
446 if solution_norm > 0.0 {
447 (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
448 } else {
449 change_norm.sqrt() <= tolerance
450 }
451 } else {
452 false
453 }
454 },
455 ConvergenceMode::Combined => {
456 residual_norm <= tolerance &&
458 (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
459 },
460 }
461 }
462}
463
/// Forward-push solver. Its `SolverAlgorithm` implementation is currently a
/// stub that reports "Not implemented yet".
pub struct ForwardPushSolver;

/// Backward-push solver. Its `SolverAlgorithm` implementation is currently a
/// placeholder.
pub struct BackwardPushSolver;

/// Hybrid solver. Its `SolverAlgorithm` implementation is currently a
/// placeholder.
pub struct HybridSolver;
468
469impl SolverAlgorithm for ForwardPushSolver {
471 type State = ();
472
473 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> {
474 Err(SolverError::AlgorithmError {
475 algorithm: "forward_push".to_string(),
476 message: "Not implemented yet".to_string(),
477 context: vec![],
478 })
479 }
480
481 fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
482 Err(SolverError::AlgorithmError {
483 algorithm: "forward_push".to_string(),
484 message: "Not implemented yet".to_string(),
485 context: vec![],
486 })
487 }
488
489 fn is_converged(&self, _state: &Self::State) -> bool {
490 false
491 }
492
493 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
494 Vec::new()
495 }
496
497 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
498 Err(SolverError::AlgorithmError {
499 algorithm: "forward_push".to_string(),
500 message: "Not implemented yet".to_string(),
501 context: vec![],
502 })
503 }
504
505 fn algorithm_name(&self) -> &'static str {
506 "forward_push"
507 }
508}
509
510impl SolverState for () {
511 fn residual_norm(&self) -> Precision {
512 0.0
513 }
514
515 fn matvec_count(&self) -> usize {
516 0
517 }
518
519 fn error_bounds(&self) -> Option<ErrorBounds> {
520 None
521 }
522
523 fn memory_usage(&self) -> MemoryInfo {
524 MemoryInfo {
525 current_usage_bytes: 0,
526 peak_usage_bytes: 0,
527 matrix_memory_bytes: 0,
528 vector_memory_bytes: 0,
529 workspace_memory_bytes: 0,
530 allocation_count: 0,
531 deallocation_count: 0,
532 }
533 }
534
535 fn reset(&mut self) {}
536}
537
538impl SolverAlgorithm for BackwardPushSolver {
540 type State = ();
541 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
542 fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
543 fn is_converged(&self, _state: &Self::State) -> bool { true }
544 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
545 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
546 fn algorithm_name(&self) -> &'static str { "backward_push" }
547}
548
549impl SolverAlgorithm for HybridSolver {
550 type State = ();
551 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
552 fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
553 fn is_converged(&self, _state: &Self::State) -> bool { true }
554 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
555 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
556 fn algorithm_name(&self) -> &'static str { "hybrid" }
557}
558
// Unit tests for the option presets, result constructors and `utils` helpers.
// The previously imported `crate::matrix::SparseMatrix` was never used and
// has been dropped.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);
    }

    #[test]
    fn test_streaming_options() {
        let opts = SolverOptions::streaming(25);
        assert_eq!(opts.streaming_interval, 25);
        assert_eq!(opts.tolerance, 1e-4);
        assert!(opts.collect_stats);
        assert!(opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));
    }

    #[test]
    fn test_failure_and_error_results() {
        // A non-converged result never meets quality criteria, whatever the residual.
        let failed = SolverResult::failure(vec![0.5], 1e-12, 7);
        assert!(!failed.converged);
        assert_eq!(failed.iterations, 7);
        assert!(!failed.meets_quality_criteria(1.0));

        let errored = SolverResult::error(SolverError::AlgorithmError {
            algorithm: "test".to_string(),
            message: "boom".to_string(),
            context: vec![],
        });
        assert!(errored.solution.is_empty());
        assert!(errored.residual_norm.is_infinite());
        assert_eq!(errored.iterations, 0);
        assert!(!errored.converged);
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = vec![3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);

        assert_eq!(compute_norm(&v, NormType::L1), 7.0);
        assert_eq!(compute_norm(&v, NormType::L2), 5.0);
        assert_eq!(compute_norm(&v, NormType::LInfinity), 4.0);
        // Weighted currently falls back to L2.
        assert_eq!(compute_norm(&v, NormType::Weighted), 5.0);
    }

    #[test]
    fn test_check_convergence_modes() {
        use utils::check_convergence;

        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[0.0]));
        assert!(!check_convergence(1e-3, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[0.0]));

        // Relative test passes where the absolute one would not.
        assert!(check_convergence(1e-3, 1e-6, ConvergenceMode::RelativeResidual, 1e4, None, &[0.0]));

        // Solution-change modes need a previous iterate.
        assert!(!check_convergence(0.0, 1e-6, ConvergenceMode::SolutionChange, 1.0, None, &[1.0]));
        assert!(check_convergence(
            1.0,
            1e-6,
            ConvergenceMode::SolutionChange,
            1.0,
            Some(&[1.0, 2.0][..]),
            &[1.0, 2.0],
        ));

        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::Combined, 0.0, None, &[0.0]));
    }
}