1use crate::matrix::Matrix;
7use crate::types::{Precision, ErrorBounds, ErrorBoundMethod, MemoryInfo};
8use crate::error::{SolverError, Result};
9use crate::solver::{
10 SolverAlgorithm, SolverState, SolverOptions, StepResult, utils
11};
12use alloc::{vec::Vec, string::String};
13
/// Configuration for the Neumann-series solver.
///
/// `NeumannState::new` rejects matrices that are not square or not diagonally dominant;
/// the solver then sums terms of a series built from the inverse diagonal of the matrix.
#[derive(Debug, Clone)]
pub struct NeumannSolver {
    /// Maximum number of series terms that will be summed.
    max_terms: usize,
    /// A term whose L2 norm falls below this value marks the series as converged.
    series_tolerance: Precision,
    /// When set (and error bounds are requested via `SolverOptions`), `solve` estimates a
    /// truncation error bound after each term.
    adaptive_truncation: bool,
    /// Requests a matrix-power cache. NOTE(review): in this file the cache is only
    /// pre-allocated and cleared, never filled, so this flag does not affect results.
    cache_powers: bool,
}
34
35impl NeumannSolver {
36 pub fn new(max_terms: usize, series_tolerance: Precision) -> Self {
49 Self {
50 max_terms,
51 series_tolerance,
52 adaptive_truncation: true,
53 cache_powers: true,
54 }
55 }
56
57 pub fn default() -> Self {
59 Self::new(50, 1e-8)
60 }
61
62 pub fn high_precision() -> Self {
64 Self {
65 max_terms: 100,
66 series_tolerance: 1e-12,
67 adaptive_truncation: true,
68 cache_powers: true,
69 }
70 }
71
72 pub fn fast() -> Self {
74 Self {
75 max_terms: 20,
76 series_tolerance: 1e-6,
77 adaptive_truncation: false,
78 cache_powers: false,
79 }
80 }
81
82 pub fn with_adaptive_truncation(mut self, enable: bool) -> Self {
84 self.adaptive_truncation = enable;
85 self
86 }
87
88 pub fn with_power_caching(mut self, enable: bool) -> Self {
90 self.cache_powers = enable;
91 self
92 }
93}
94
/// Per-solve state for `NeumannSolver`.
///
/// All vectors are dense with `dimension` entries. Each new series term is derived from the
/// previous one by multiplying with the matrix and scaling by `diagonal_inv`.
#[derive(Debug, Clone)]
pub struct NeumannState {
    /// Number of rows (and columns) of the square system matrix.
    dimension: usize,
    /// Current solution estimate; `compute_next_term` adds each series term to it.
    solution: Vec<Precision>,
    /// Right-hand side scaled entry-wise by `diagonal_inv` (i.e. `D^{-1} b`).
    rhs: Vec<Precision>,
    /// Residual buffer written by `update_residual`.
    residual: Vec<Precision>,
    /// L2 norm of `residual` from the most recent update; starts at `INFINITY`.
    residual_norm: Precision,
    /// Reciprocals of the matrix diagonal entries.
    diagonal_inv: Vec<Precision>,
    /// Explicit iteration matrix; always `None` in this file.
    #[allow(dead_code)]
    iteration_matrix: Option<Vec<Vec<Precision>>>,
    /// Matrix-power cache. NOTE(review): only pre-allocated and cleared here, never filled.
    matrix_powers: Vec<Vec<Precision>>,
    /// Latest series term: initially the first term to add, afterwards the term most
    /// recently added to `solution`.
    current_term: Vec<Precision>,
    /// Number of terms added to `solution` so far.
    terms_computed: usize,
    /// Number of matrix-vector products performed.
    matvec_count: usize,
    /// Copy of `solution` taken before each step; not read elsewhere in this file.
    previous_solution: Option<Vec<Precision>>,
    /// Set once an added term's L2 norm falls below `series_tolerance`.
    series_converged: bool,
    /// Truncation error estimate, if one has been computed.
    error_bounds: Option<ErrorBounds>,
    /// Memory estimate recorded at initialization.
    memory_usage: MemoryInfo,
    /// Residual-norm tolerance copied from `SolverOptions`.
    tolerance: Precision,
    /// Term limit copied from the solver configuration.
    max_terms: usize,
    /// Per-term norm threshold copied from the solver configuration.
    series_tolerance: Precision,
}
136
137impl NeumannState {
138 fn new(
140 matrix: &dyn Matrix,
141 b: &[Precision],
142 options: &SolverOptions,
143 solver_config: &NeumannSolver,
144 ) -> Result<Self> {
145 let dimension = matrix.rows();
146
147 if !matrix.is_square() {
148 return Err(SolverError::InvalidInput {
149 message: "Matrix must be square for Neumann series".to_string(),
150 parameter: Some("matrix_dimensions".to_string()),
151 });
152 }
153
154 if b.len() != dimension {
155 return Err(SolverError::DimensionMismatch {
156 expected: dimension,
157 actual: b.len(),
158 operation: "neumann_initialization".to_string(),
159 });
160 }
161
162 if !matrix.is_diagonally_dominant() {
164 return Err(SolverError::MatrixNotDiagonallyDominant {
165 row: 0, diagonal: 0.0,
167 off_diagonal_sum: 0.0,
168 });
169 }
170
171 let mut diagonal_inv = vec![0.0; dimension];
173 for i in 0..dimension {
174 if let Some(diag_val) = matrix.get(i, i) {
175 if diag_val.abs() < 1e-14 {
176 return Err(SolverError::InvalidSparseMatrix {
177 reason: format!("Zero or near-zero diagonal element at position {}", i),
178 position: Some((i, i)),
179 });
180 }
181 diagonal_inv[i] = 1.0 / diag_val;
182 } else {
183 return Err(SolverError::InvalidSparseMatrix {
184 reason: format!("Missing diagonal element at position {}", i),
185 position: Some((i, i)),
186 });
187 }
188 }
189
190 let rhs: Vec<Precision> = b.iter()
192 .zip(&diagonal_inv)
193 .map(|(&b_val, &d_inv)| b_val * d_inv)
194 .collect();
195
196 let solution = if let Some(ref initial) = options.initial_guess {
198 if initial.len() != dimension {
199 return Err(SolverError::DimensionMismatch {
200 expected: dimension,
201 actual: initial.len(),
202 operation: "initial_guess".to_string(),
203 });
204 }
205 initial.clone()
206 } else {
207 rhs.clone() };
209
210 let residual = vec![0.0; dimension];
211 let current_term = rhs.clone();
212
213 let matrix_powers = if solver_config.cache_powers {
214 Vec::with_capacity(solver_config.max_terms)
215 } else {
216 Vec::new()
217 };
218
219 let memory_usage = MemoryInfo {
220 current_usage_bytes: dimension * 8 * 5, peak_usage_bytes: dimension * 8 * 5,
222 matrix_memory_bytes: 0, vector_memory_bytes: dimension * 8 * 5,
224 workspace_memory_bytes: 0,
225 allocation_count: 5,
226 deallocation_count: 0,
227 };
228
229 Ok(Self {
230 dimension,
231 solution,
232 rhs,
233 residual,
234 residual_norm: Precision::INFINITY,
235 diagonal_inv,
236 iteration_matrix: None,
237 matrix_powers,
238 current_term,
239 terms_computed: 0,
240 matvec_count: 0,
241 previous_solution: None,
242 series_converged: false,
243 error_bounds: None,
244 memory_usage,
245 tolerance: options.tolerance,
246 max_terms: solver_config.max_terms,
247 series_tolerance: solver_config.series_tolerance,
248 })
249 }
250
251 fn compute_next_term(&mut self, matrix: &dyn Matrix) -> Result<()> {
253 if self.terms_computed >= self.max_terms {
254 return Ok(());
255 }
256
257 if self.terms_computed > 0 {
260 self.apply_iteration_matrix(matrix)?;
261 }
262
263 for (sol, &term) in self.solution.iter_mut().zip(&self.current_term) {
265 *sol += term;
266 }
267
268 self.terms_computed += 1;
269
270 let term_norm = utils::l2_norm(&self.current_term);
272 if term_norm < self.series_tolerance {
273 self.series_converged = true;
274 }
275
276 Ok(())
277 }
278
279 fn apply_iteration_matrix(&mut self, matrix: &dyn Matrix) -> Result<()> {
281 let mut temp_vec = vec![0.0; self.dimension];
283
284 matrix.multiply_vector(&self.current_term, &mut temp_vec)?;
286 self.matvec_count += 1;
287
288 for (temp, &d_inv) in temp_vec.iter_mut().zip(&self.diagonal_inv) {
290 *temp *= d_inv;
291 }
292
293 for (curr, &temp) in self.current_term.iter_mut().zip(&temp_vec) {
295 *curr -= temp;
296 }
297
298 Ok(())
299 }
300
301 fn update_residual(&mut self, matrix: &dyn Matrix) -> Result<()> {
303 matrix.multiply_vector(&self.solution, &mut self.residual)?;
305 self.matvec_count += 1;
306
307 for (r, &b_val) in self.residual.iter_mut().zip(self.rhs.iter()) {
309 *r = *r - b_val; }
311
312 self.residual_norm = utils::l2_norm(&self.residual);
317 Ok(())
318 }
319
320 fn estimate_error_bounds(&mut self) -> Result<()> {
322 if !self.series_converged || self.terms_computed == 0 {
323 return Ok(());
324 }
325
326 let mut matrix_norm_estimate = 0.0;
328 if self.terms_computed > 1 {
329 let term_ratio = utils::l2_norm(&self.current_term) /
330 utils::l2_norm(&self.rhs);
331 matrix_norm_estimate = term_ratio.powf(1.0 / (self.terms_computed - 1) as Precision);
332 }
333
334 if matrix_norm_estimate < 1.0 {
335 let remaining_sum_bound = matrix_norm_estimate.powi(self.terms_computed as i32) /
337 (1.0 - matrix_norm_estimate);
338 let error_bound = remaining_sum_bound * utils::l2_norm(&self.rhs);
339
340 self.error_bounds = Some(ErrorBounds::upper_bound_only(
341 error_bound,
342 ErrorBoundMethod::NeumannTruncation,
343 ));
344 }
345
346 Ok(())
347 }
348}
349
350impl SolverState for NeumannState {
351 fn residual_norm(&self) -> Precision {
352 self.residual_norm
353 }
354
355 fn matvec_count(&self) -> usize {
356 self.matvec_count
357 }
358
359 fn error_bounds(&self) -> Option<ErrorBounds> {
360 self.error_bounds.clone()
361 }
362
363 fn memory_usage(&self) -> MemoryInfo {
364 self.memory_usage.clone()
365 }
366
367 fn reset(&mut self) {
368 self.solution.fill(0.0);
369 self.residual.fill(0.0);
370 self.residual_norm = Precision::INFINITY;
371 self.current_term = self.rhs.clone();
372 self.terms_computed = 0;
373 self.matvec_count = 0;
374 self.previous_solution = None;
375 self.series_converged = false;
376 self.error_bounds = None;
377 self.matrix_powers.clear();
378 }
379}
380
381impl SolverAlgorithm for NeumannSolver {
382 type State = NeumannState;
383
384 fn initialize(
385 &self,
386 matrix: &dyn Matrix,
387 b: &[Precision],
388 options: &SolverOptions,
389 ) -> Result<Self::State> {
390 NeumannState::new(matrix, b, options, self)
391 }
392
393 fn step(&self, state: &mut Self::State) -> Result<StepResult> {
394 state.previous_solution = Some(state.solution.clone());
396
397 return Err(SolverError::AlgorithmError {
402 algorithm: "neumann".to_string(),
403 message: "Matrix reference needed for iteration - design limitation".to_string(),
404 context: vec![],
405 });
406
407 }
421
422 fn is_converged(&self, state: &Self::State) -> bool {
423 let residual_converged = state.residual_norm <= state.tolerance;
425 let series_converged = state.series_converged;
426 let max_terms_reached = state.terms_computed >= state.max_terms;
427
428 residual_converged || (series_converged && !max_terms_reached)
430 }
431
432 fn extract_solution(&self, state: &Self::State) -> Vec<Precision> {
433 state.solution.clone()
434 }
435
436 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()> {
437 for &(index, delta) in delta_b {
439 if index >= state.dimension {
440 return Err(SolverError::IndexOutOfBounds {
441 index,
442 max_index: state.dimension - 1,
443 context: "rhs_update".to_string(),
444 });
445 }
446
447 let scaled_delta = delta * state.diagonal_inv[index];
449 state.rhs[index] += scaled_delta;
450
451 state.solution[index] += scaled_delta;
454 }
455
456 state.current_term = state.rhs.clone();
458 state.terms_computed = 0;
459 state.series_converged = false;
460
461 Ok(())
462 }
463
464 fn algorithm_name(&self) -> &'static str {
465 "neumann"
466 }
467
468 fn solve(
470 &self,
471 matrix: &dyn Matrix,
472 b: &[Precision],
473 options: &SolverOptions,
474 ) -> Result<crate::solver::SolverResult> {
475 let mut state = self.initialize(matrix, b, options)?;
476 let mut iterations = 0;
477
478 #[cfg(feature = "std")]
479 let start_time = std::time::Instant::now();
480
481 while !self.is_converged(&state) && iterations < options.max_iterations {
482 state.previous_solution = Some(state.solution.clone());
484
485 state.compute_next_term(matrix)?;
487
488 if iterations % 5 == 0 {
490 state.update_residual(matrix)?;
491 }
492
493 if options.compute_error_bounds && self.adaptive_truncation {
495 state.estimate_error_bounds()?;
496 }
497
498 iterations += 1;
499
500 if !state.residual_norm.is_finite() {
502 return Err(SolverError::NumericalInstability {
503 reason: "Non-finite residual norm".to_string(),
504 iteration: iterations,
505 residual_norm: state.residual_norm,
506 });
507 }
508
509 if state.series_converged {
511 break;
512 }
513 }
514
515 state.update_residual(matrix)?;
517
518 let converged = self.is_converged(&state);
519 let solution = self.extract_solution(&state);
520 let residual_norm = state.residual_norm();
521
522 if !converged && iterations >= options.max_iterations {
524 return Err(SolverError::ConvergenceFailure {
525 iterations,
526 residual_norm,
527 tolerance: options.tolerance,
528 algorithm: self.algorithm_name().to_string(),
529 });
530 }
531
532 let mut result = if converged {
533 crate::solver::SolverResult::success(solution, residual_norm, iterations)
534 } else {
535 crate::solver::SolverResult::failure(solution, residual_norm, iterations)
536 };
537
538 if options.collect_stats {
540 #[cfg(feature = "std")]
541 {
542 let total_time = start_time.elapsed().as_millis() as f64;
543 let mut stats = crate::types::SolverStats::new();
544 stats.total_time_ms = total_time;
545 stats.matvec_count = state.matvec_count();
546 result.stats = Some(stats);
547 }
548 }
549
550 if options.compute_error_bounds {
551 result.error_bounds = state.error_bounds();
552 }
553
554 Ok(result)
555 }
556}
557
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;

    #[test]
    fn test_neumann_solver_creation() {
        let solver = NeumannSolver::new(16, 1e-8);
        assert_eq!(solver.max_terms, 16);
        assert_eq!(solver.series_tolerance, 1e-8);
        assert!(solver.adaptive_truncation);
        assert!(solver.cache_powers);

        let fast_solver = NeumannSolver::fast();
        assert_eq!(fast_solver.max_terms, 20);
        assert!(!fast_solver.cache_powers);
    }

    #[test]
    fn test_neumann_solver_simple_system() {
        // 4x + y = 5 and x + 3y = 4 have the exact solution x = y = 1.
        let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        let b = vec![5.0, 4.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let mut options = SolverOptions::default();
        // Leave enough iterations for the ~17 terms this system needs.
        options.max_iterations = 100;

        // A diagonally dominant system must solve; an error is a test failure, not an
        // "expected" outcome.
        let result = solver
            .solve(&matrix, &b, &options)
            .expect("diagonally dominant system should solve");

        assert!(result.converged);
        let x = result.solution;
        assert!((x[0] - 1.0).abs() < 1e-6, "x[0] = {}", x[0]);
        assert!((x[1] - 1.0).abs() < 1e-6, "x[1] = {}", x[1]);
    }

    #[test]
    fn test_neumann_not_diagonally_dominant() {
        let triplets = vec![(0, 0, 1.0), (0, 1, 3.0), (1, 0, 2.0), (1, 1, 1.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        let b = vec![4.0, 3.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let options = SolverOptions::default();

        let result = solver.solve(&matrix, &b, &options);

        assert!(
            matches!(result, Err(SolverError::MatrixNotDiagonallyDominant { .. })),
            "Expected MatrixNotDiagonallyDominant error"
        );
    }

    #[test]
    fn test_neumann_state_initialization() {
        let triplets = vec![(0, 0, 2.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        let b = vec![4.0, 6.0];
        let solver = NeumannSolver::default();
        let options = SolverOptions::default();

        let state = solver.initialize(&matrix, &b, &options).unwrap();

        assert_eq!(state.dimension, 2);
        assert_eq!(state.diagonal_inv, vec![0.5, 1.0 / 3.0]);
        assert_eq!(state.rhs, vec![2.0, 2.0]);
        assert_eq!(state.terms_computed, 0);
        assert!(!state.series_converged);
    }
}