1use crate::matrix::Matrix;
7use crate::types::{Precision, ErrorBounds, ErrorBoundMethod, MemoryInfo};
8use crate::error::{SolverError, Result};
9use crate::solver::{
10 SolverAlgorithm, SolverState, SolverOptions, StepResult, utils
11};
12use alloc::{vec::Vec, string::String};
13
14#[derive(Debug, Clone)]
24pub struct NeumannSolver {
25 max_terms: usize,
27 series_tolerance: Precision,
29 adaptive_truncation: bool,
31 cache_powers: bool,
33}
34
35impl NeumannSolver {
36 pub fn new(max_terms: usize, series_tolerance: Precision) -> Self {
49 Self {
50 max_terms,
51 series_tolerance,
52 adaptive_truncation: true,
53 cache_powers: true,
54 }
55 }
56
57 pub fn default() -> Self {
59 Self::new(50, 1e-8)
60 }
61
62 pub fn high_precision() -> Self {
64 Self {
65 max_terms: 100,
66 series_tolerance: 1e-12,
67 adaptive_truncation: true,
68 cache_powers: true,
69 }
70 }
71
72 pub fn fast() -> Self {
74 Self {
75 max_terms: 20,
76 series_tolerance: 1e-6,
77 adaptive_truncation: false,
78 cache_powers: false,
79 }
80 }
81
82 pub fn with_adaptive_truncation(mut self, enable: bool) -> Self {
84 self.adaptive_truncation = enable;
85 self
86 }
87
88 pub fn with_power_caching(mut self, enable: bool) -> Self {
90 self.cache_powers = enable;
91 self
92 }
93}
94
95#[derive(Debug, Clone)]
97pub struct NeumannState {
98 dimension: usize,
100 solution: Vec<Precision>,
102 rhs: Vec<Precision>,
104 original_rhs: Vec<Precision>,
111 residual: Vec<Precision>,
113 residual_norm: Precision,
115 diagonal_inv: Vec<Precision>,
117 #[allow(dead_code)]
119 iteration_matrix: Option<Vec<Vec<Precision>>>,
120 matrix_powers: Vec<Vec<Precision>>,
122 current_term: Vec<Precision>,
124 terms_computed: usize,
126 matvec_count: usize,
128 previous_solution: Option<Vec<Precision>>,
130 series_converged: bool,
132 error_bounds: Option<ErrorBounds>,
134 memory_usage: MemoryInfo,
136 tolerance: Precision,
138 max_terms: usize,
140 series_tolerance: Precision,
142}
143
144impl NeumannState {
145 fn new(
147 matrix: &dyn Matrix,
148 b: &[Precision],
149 options: &SolverOptions,
150 solver_config: &NeumannSolver,
151 ) -> Result<Self> {
152 let dimension = matrix.rows();
153
154 if !matrix.is_square() {
155 return Err(SolverError::InvalidInput {
156 message: "Matrix must be square for Neumann series".to_string(),
157 parameter: Some("matrix_dimensions".to_string()),
158 });
159 }
160
161 if b.len() != dimension {
162 return Err(SolverError::DimensionMismatch {
163 expected: dimension,
164 actual: b.len(),
165 operation: "neumann_initialization".to_string(),
166 });
167 }
168
169 if !matrix.is_diagonally_dominant() {
171 return Err(SolverError::MatrixNotDiagonallyDominant {
172 row: 0, diagonal: 0.0,
174 off_diagonal_sum: 0.0,
175 });
176 }
177
178 let mut diagonal_inv = vec![0.0; dimension];
180 for i in 0..dimension {
181 if let Some(diag_val) = matrix.get(i, i) {
182 if diag_val.abs() < 1e-14 {
183 return Err(SolverError::InvalidSparseMatrix {
184 reason: format!("Zero or near-zero diagonal element at position {}", i),
185 position: Some((i, i)),
186 });
187 }
188 diagonal_inv[i] = 1.0 / diag_val;
189 } else {
190 return Err(SolverError::InvalidSparseMatrix {
191 reason: format!("Missing diagonal element at position {}", i),
192 position: Some((i, i)),
193 });
194 }
195 }
196
197 let rhs: Vec<Precision> = b.iter()
199 .zip(&diagonal_inv)
200 .map(|(&b_val, &d_inv)| b_val * d_inv)
201 .collect();
202
203 let solution = if let Some(ref initial) = options.initial_guess {
209 if initial.len() != dimension {
210 return Err(SolverError::DimensionMismatch {
211 expected: dimension,
212 actual: initial.len(),
213 operation: "initial_guess".to_string(),
214 });
215 }
216 initial.clone()
217 } else {
218 vec![0.0; dimension]
219 };
220
221 let residual = vec![0.0; dimension];
222 let current_term = rhs.clone();
223
224 let matrix_powers = if solver_config.cache_powers {
225 Vec::with_capacity(solver_config.max_terms)
226 } else {
227 Vec::new()
228 };
229
230 let memory_usage = MemoryInfo {
231 current_usage_bytes: dimension * 8 * 5, peak_usage_bytes: dimension * 8 * 5,
233 matrix_memory_bytes: 0, vector_memory_bytes: dimension * 8 * 5,
235 workspace_memory_bytes: 0,
236 allocation_count: 5,
237 deallocation_count: 0,
238 };
239
240 Ok(Self {
241 dimension,
242 solution,
243 rhs,
244 original_rhs: b.to_vec(),
245 residual,
246 residual_norm: Precision::INFINITY,
247 diagonal_inv,
248 iteration_matrix: None,
249 matrix_powers,
250 current_term,
251 terms_computed: 0,
252 matvec_count: 0,
253 previous_solution: None,
254 series_converged: false,
255 error_bounds: None,
256 memory_usage,
257 tolerance: options.tolerance,
258 max_terms: solver_config.max_terms,
259 series_tolerance: solver_config.series_tolerance,
260 })
261 }
262
263 fn compute_next_term(&mut self, matrix: &dyn Matrix) -> Result<()> {
265 if self.terms_computed >= self.max_terms {
266 return Ok(());
267 }
268
269 if self.terms_computed > 0 {
272 self.apply_iteration_matrix(matrix)?;
273 }
274
275 for (sol, &term) in self.solution.iter_mut().zip(&self.current_term) {
277 *sol += term;
278 }
279
280 self.terms_computed += 1;
281
282 let term_norm = utils::l2_norm(&self.current_term);
284 if term_norm < self.series_tolerance {
285 self.series_converged = true;
286 }
287
288 Ok(())
289 }
290
291 fn apply_iteration_matrix(&mut self, matrix: &dyn Matrix) -> Result<()> {
293 let mut temp_vec = vec![0.0; self.dimension];
295
296 matrix.multiply_vector(&self.current_term, &mut temp_vec)?;
298 self.matvec_count += 1;
299
300 for (temp, &d_inv) in temp_vec.iter_mut().zip(&self.diagonal_inv) {
302 *temp *= d_inv;
303 }
304
305 for (curr, &temp) in self.current_term.iter_mut().zip(&temp_vec) {
307 *curr -= temp;
308 }
309
310 Ok(())
311 }
312
313 fn update_residual(&mut self, matrix: &dyn Matrix) -> Result<()> {
324 matrix.multiply_vector(&self.solution, &mut self.residual)?;
326 self.matvec_count += 1;
327
328 for (r, &b_val) in self.residual.iter_mut().zip(self.original_rhs.iter()) {
330 *r -= b_val;
331 }
332
333 self.residual_norm = utils::l2_norm(&self.residual);
334 Ok(())
335 }
336
337 fn estimate_error_bounds(&mut self) -> Result<()> {
339 if !self.series_converged || self.terms_computed == 0 {
340 return Ok(());
341 }
342
343 let mut matrix_norm_estimate = 0.0;
345 if self.terms_computed > 1 {
346 let term_ratio = utils::l2_norm(&self.current_term) /
347 utils::l2_norm(&self.rhs);
348 matrix_norm_estimate = term_ratio.powf(1.0 / (self.terms_computed - 1) as Precision);
349 }
350
351 if matrix_norm_estimate < 1.0 {
352 let remaining_sum_bound = matrix_norm_estimate.powi(self.terms_computed as i32) /
354 (1.0 - matrix_norm_estimate);
355 let error_bound = remaining_sum_bound * utils::l2_norm(&self.rhs);
356
357 self.error_bounds = Some(ErrorBounds::upper_bound_only(
358 error_bound,
359 ErrorBoundMethod::NeumannTruncation,
360 ));
361 }
362
363 Ok(())
364 }
365}
366
367impl SolverState for NeumannState {
368 fn residual_norm(&self) -> Precision {
369 self.residual_norm
370 }
371
372 fn matvec_count(&self) -> usize {
373 self.matvec_count
374 }
375
376 fn error_bounds(&self) -> Option<ErrorBounds> {
377 self.error_bounds.clone()
378 }
379
380 fn memory_usage(&self) -> MemoryInfo {
381 self.memory_usage.clone()
382 }
383
384 fn reset(&mut self) {
385 self.solution.fill(0.0);
386 self.residual.fill(0.0);
387 self.residual_norm = Precision::INFINITY;
388 self.current_term = self.rhs.clone();
389 self.terms_computed = 0;
390 self.matvec_count = 0;
391 self.previous_solution = None;
392 self.series_converged = false;
393 self.error_bounds = None;
394 self.matrix_powers.clear();
395 }
396}
397
398impl SolverAlgorithm for NeumannSolver {
399 type State = NeumannState;
400
401 fn initialize(
402 &self,
403 matrix: &dyn Matrix,
404 b: &[Precision],
405 options: &SolverOptions,
406 ) -> Result<Self::State> {
407 NeumannState::new(matrix, b, options, self)
408 }
409
410 fn step(&self, state: &mut Self::State) -> Result<StepResult> {
411 state.previous_solution = Some(state.solution.clone());
413
414 return Err(SolverError::AlgorithmError {
419 algorithm: "neumann".to_string(),
420 message: "Matrix reference needed for iteration - design limitation".to_string(),
421 context: vec![],
422 });
423
424 }
438
439 fn is_converged(&self, state: &Self::State) -> bool {
440 let residual_converged = state.residual_norm <= state.tolerance;
442 let series_converged = state.series_converged;
443 let max_terms_reached = state.terms_computed >= state.max_terms;
444
445 residual_converged || (series_converged && !max_terms_reached)
447 }
448
449 fn extract_solution(&self, state: &Self::State) -> Vec<Precision> {
450 state.solution.clone()
451 }
452
453 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()> {
454 for &(index, delta) in delta_b {
456 if index >= state.dimension {
457 return Err(SolverError::IndexOutOfBounds {
458 index,
459 max_index: state.dimension - 1,
460 context: "rhs_update".to_string(),
461 });
462 }
463
464 let scaled_delta = delta * state.diagonal_inv[index];
466 state.rhs[index] += scaled_delta;
467
468 state.solution[index] += scaled_delta;
471 }
472
473 state.current_term = state.rhs.clone();
475 state.terms_computed = 0;
476 state.series_converged = false;
477
478 Ok(())
479 }
480
481 fn algorithm_name(&self) -> &'static str {
482 "neumann"
483 }
484
485 fn solve(
487 &self,
488 matrix: &dyn Matrix,
489 b: &[Precision],
490 options: &SolverOptions,
491 ) -> Result<crate::solver::SolverResult> {
492 let mut state = self.initialize(matrix, b, options)?;
493 let mut iterations = 0;
494
495 #[cfg(feature = "std")]
496 let start_time = std::time::Instant::now();
497
498 while !self.is_converged(&state) && iterations < options.max_iterations {
499 state.previous_solution = Some(state.solution.clone());
501
502 state.compute_next_term(matrix)?;
504
505 if iterations % 5 == 0 {
507 state.update_residual(matrix)?;
508 }
509
510 if options.compute_error_bounds && self.adaptive_truncation {
512 state.estimate_error_bounds()?;
513 }
514
515 iterations += 1;
516
517 if !state.residual_norm.is_finite() {
519 return Err(SolverError::NumericalInstability {
520 reason: "Non-finite residual norm".to_string(),
521 iteration: iterations,
522 residual_norm: state.residual_norm,
523 });
524 }
525
526 if state.series_converged {
528 break;
529 }
530 }
531
532 state.update_residual(matrix)?;
534
535 let converged = self.is_converged(&state);
536 let solution = self.extract_solution(&state);
537 let residual_norm = state.residual_norm();
538
539 if !converged && iterations >= options.max_iterations {
541 return Err(SolverError::ConvergenceFailure {
542 iterations,
543 residual_norm,
544 tolerance: options.tolerance,
545 algorithm: self.algorithm_name().to_string(),
546 });
547 }
548
549 let mut result = if converged {
550 crate::solver::SolverResult::success(solution, residual_norm, iterations)
551 } else {
552 crate::solver::SolverResult::failure(solution, residual_norm, iterations)
553 };
554
555 if options.collect_stats {
557 #[cfg(feature = "std")]
558 {
559 let total_time = start_time.elapsed().as_millis() as f64;
560 let mut stats = crate::types::SolverStats::new();
561 stats.total_time_ms = total_time;
562 stats.matvec_count = state.matvec_count();
563 result.stats = Some(stats);
564 }
565 }
566
567 if options.compute_error_bounds {
568 result.error_bounds = state.error_bounds();
569 }
570
571 Ok(result)
572 }
573}
574
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;

    /// `new` stores its arguments and enables both optional features;
    /// the `fast` preset uses 20 terms and disables power caching.
    #[test]
    fn test_neumann_solver_creation() {
        let configured = NeumannSolver::new(16, 1e-8);
        assert_eq!(configured.max_terms, 16);
        assert_eq!(configured.series_tolerance, 1e-8);
        assert!(configured.adaptive_truncation);
        assert!(configured.cache_powers);

        let quick = NeumannSolver::fast();
        assert_eq!(quick.max_terms, 20);
        assert!(!quick.cache_powers);
    }

    /// Solves a 2x2 system whose exact solution is `x = [1, 1]`.
    #[test]
    fn test_neumann_solver_simple_system() {
        let entries = vec![
            (0, 0, 4.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 3.0),
        ];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![5.0, 4.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let options = SolverOptions::default();

        match solver.solve(&matrix, &b, &options) {
            Ok(outcome) => {
                assert!(outcome.converged);
                let x = outcome.solution;
                assert!((x[0] - 1.0).abs() < 0.1);
                assert!((x[1] - 1.0).abs() < 0.1);
            }
            // NOTE(review): a solver error is only printed here, so this
            // test cannot fail when `solve` returns `Err`.
            Err(e) => {
                println!("Expected error: {:?}", e);
            }
        }
    }

    /// A matrix whose off-diagonal entries dominate must be rejected.
    #[test]
    fn test_neumann_not_diagonally_dominant() {
        let entries = vec![
            (0, 0, 1.0), (0, 1, 3.0),
            (1, 0, 2.0), (1, 1, 1.0),
        ];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![4.0, 3.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let options = SolverOptions::default();

        let outcome = solver.solve(&matrix, &b, &options);

        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(SolverError::MatrixNotDiagonallyDominant { .. })),
            "Expected MatrixNotDiagonallyDominant error"
        );
    }

    /// Initialization on a diagonal matrix yields the inverse diagonal, the
    /// scaled right-hand side and a fresh series.
    #[test]
    fn test_neumann_state_initialization() {
        let entries = vec![(0, 0, 2.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![4.0, 6.0];
        let solver = NeumannSolver::default();
        let options = SolverOptions::default();

        let state = solver.initialize(&matrix, &b, &options).unwrap();

        assert_eq!(state.dimension, 2);
        assert_eq!(state.diagonal_inv, vec![0.5, 1.0 / 3.0]);
        assert_eq!(state.rhs, vec![2.0, 2.0]);
        assert_eq!(state.terms_computed, 0);
        assert!(!state.series_converged);
    }
}