1use crate::error::{Result, SolverError};
7use crate::matrix::Matrix;
8use crate::solver::{utils, SolverAlgorithm, SolverOptions, SolverState, StepResult};
9use crate::types::{ErrorBoundMethod, ErrorBounds, MemoryInfo, Precision};
10use alloc::vec::Vec;
11
/// Configuration for the Neumann-series solver.
///
/// The settings held here are read when a [`NeumannState`] is initialized and
/// while `solve` runs.
#[derive(Debug, Clone)]
pub struct NeumannSolver {
    /// Upper limit on the number of series terms accumulated per solve.
    max_terms: usize,
    /// A series term whose L2 norm falls below this value marks the series as converged.
    series_tolerance: Precision,
    /// When `true` (and the solver options request error bounds), `solve`
    /// estimates truncation error bounds during iteration.
    adaptive_truncation: bool,
    /// When `true`, state initialization reserves capacity for cached matrix powers.
    cache_powers: bool,
}
32
33impl NeumannSolver {
34 pub fn new(max_terms: usize, series_tolerance: Precision) -> Self {
47 Self {
48 max_terms,
49 series_tolerance,
50 adaptive_truncation: true,
51 cache_powers: true,
52 }
53 }
54
55 pub fn default() -> Self {
57 Self::new(50, 1e-8)
58 }
59
60 pub fn high_precision() -> Self {
62 Self {
63 max_terms: 100,
64 series_tolerance: 1e-12,
65 adaptive_truncation: true,
66 cache_powers: true,
67 }
68 }
69
70 pub fn fast() -> Self {
72 Self {
73 max_terms: 20,
74 series_tolerance: 1e-6,
75 adaptive_truncation: false,
76 cache_powers: false,
77 }
78 }
79
80 pub fn with_adaptive_truncation(mut self, enable: bool) -> Self {
82 self.adaptive_truncation = enable;
83 self
84 }
85
86 pub fn with_power_caching(mut self, enable: bool) -> Self {
88 self.cache_powers = enable;
89 self
90 }
91}
92
/// Working state of a single Neumann-series solve.
#[derive(Debug, Clone)]
pub struct NeumannState {
    /// Number of rows of the system matrix (the matrix is required to be square).
    dimension: usize,
    /// Accumulated solution estimate; series terms are added to it one at a time.
    solution: Vec<Precision>,
    /// Right-hand side scaled element-wise by the inverse diagonal.
    rhs: Vec<Precision>,
    /// Unscaled copy of the right-hand side passed at initialization; used
    /// when computing the residual.
    original_rhs: Vec<Precision>,
    /// Matrix times `solution` minus `original_rhs`, as of the last residual update.
    residual: Vec<Precision>,
    /// L2 norm of `residual`; infinite until the first residual update.
    residual_norm: Precision,
    /// Reciprocals of the matrix diagonal entries.
    diagonal_inv: Vec<Precision>,
    // Never populated in this file (always `None`), hence the lint allowance.
    #[allow(dead_code)]
    iteration_matrix: Option<Vec<Vec<Precision>>>,
    /// Storage reserved for cached powers; this file only reserves and clears
    /// it, nothing here pushes entries.
    matrix_powers: Vec<Vec<Precision>>,
    /// Most recent series term.
    current_term: Vec<Precision>,
    /// Number of series terms added to `solution` so far.
    terms_computed: usize,
    /// Number of matrix-vector products performed.
    matvec_count: usize,
    /// Copy of `solution` taken before the latest iteration; written but not
    /// read anywhere in this file.
    previous_solution: Option<Vec<Precision>>,
    /// Set once a term's L2 norm drops below `series_tolerance`.
    series_converged: bool,
    /// Truncation error estimate, when one has been computed.
    error_bounds: Option<ErrorBounds>,
    /// Memory accounting snapshot recorded at initialization.
    memory_usage: MemoryInfo,
    /// Residual-norm tolerance copied from the solver options.
    tolerance: Precision,
    /// Term limit copied from the solver configuration.
    max_terms: usize,
    /// Series tolerance copied from the solver configuration.
    series_tolerance: Precision,
}
141
142impl NeumannState {
143 fn new(
145 matrix: &dyn Matrix,
146 b: &[Precision],
147 options: &SolverOptions,
148 solver_config: &NeumannSolver,
149 ) -> Result<Self> {
150 let dimension = matrix.rows();
151
152 if !matrix.is_square() {
153 return Err(SolverError::InvalidInput {
154 message: "Matrix must be square for Neumann series".to_string(),
155 parameter: Some("matrix_dimensions".to_string()),
156 });
157 }
158
159 if b.len() != dimension {
160 return Err(SolverError::DimensionMismatch {
161 expected: dimension,
162 actual: b.len(),
163 operation: "neumann_initialization".to_string(),
164 });
165 }
166
167 if !matrix.is_diagonally_dominant() {
169 return Err(SolverError::MatrixNotDiagonallyDominant {
170 row: 0, diagonal: 0.0,
172 off_diagonal_sum: 0.0,
173 });
174 }
175
176 let mut diagonal_inv = vec![0.0; dimension];
178 for i in 0..dimension {
179 if let Some(diag_val) = matrix.get(i, i) {
180 if diag_val.abs() < 1e-14 {
181 return Err(SolverError::InvalidSparseMatrix {
182 reason: format!("Zero or near-zero diagonal element at position {}", i),
183 position: Some((i, i)),
184 });
185 }
186 diagonal_inv[i] = 1.0 / diag_val;
187 } else {
188 return Err(SolverError::InvalidSparseMatrix {
189 reason: format!("Missing diagonal element at position {}", i),
190 position: Some((i, i)),
191 });
192 }
193 }
194
195 let rhs: Vec<Precision> = b
197 .iter()
198 .zip(&diagonal_inv)
199 .map(|(&b_val, &d_inv)| b_val * d_inv)
200 .collect();
201
202 let solution = if let Some(ref initial) = options.initial_guess {
208 if initial.len() != dimension {
209 return Err(SolverError::DimensionMismatch {
210 expected: dimension,
211 actual: initial.len(),
212 operation: "initial_guess".to_string(),
213 });
214 }
215 initial.clone()
216 } else {
217 vec![0.0; dimension]
218 };
219
220 let residual = vec![0.0; dimension];
221 let current_term = rhs.clone();
222
223 let matrix_powers = if solver_config.cache_powers {
224 Vec::with_capacity(solver_config.max_terms)
225 } else {
226 Vec::new()
227 };
228
229 let memory_usage = MemoryInfo {
230 current_usage_bytes: dimension * 8 * 5, peak_usage_bytes: dimension * 8 * 5,
232 matrix_memory_bytes: 0, vector_memory_bytes: dimension * 8 * 5,
234 workspace_memory_bytes: 0,
235 allocation_count: 5,
236 deallocation_count: 0,
237 };
238
239 Ok(Self {
240 dimension,
241 solution,
242 rhs,
243 original_rhs: b.to_vec(),
244 residual,
245 residual_norm: Precision::INFINITY,
246 diagonal_inv,
247 iteration_matrix: None,
248 matrix_powers,
249 current_term,
250 terms_computed: 0,
251 matvec_count: 0,
252 previous_solution: None,
253 series_converged: false,
254 error_bounds: None,
255 memory_usage,
256 tolerance: options.tolerance,
257 max_terms: solver_config.max_terms,
258 series_tolerance: solver_config.series_tolerance,
259 })
260 }
261
262 fn compute_next_term(&mut self, matrix: &dyn Matrix) -> Result<()> {
264 if self.terms_computed >= self.max_terms {
265 return Ok(());
266 }
267
268 if self.terms_computed > 0 {
271 self.apply_iteration_matrix(matrix)?;
272 }
273
274 for (sol, &term) in self.solution.iter_mut().zip(&self.current_term) {
276 *sol += term;
277 }
278
279 self.terms_computed += 1;
280
281 let term_norm = utils::l2_norm(&self.current_term);
283 if term_norm < self.series_tolerance {
284 self.series_converged = true;
285 }
286
287 Ok(())
288 }
289
290 fn apply_iteration_matrix(&mut self, matrix: &dyn Matrix) -> Result<()> {
292 let mut temp_vec = vec![0.0; self.dimension];
294
295 matrix.multiply_vector(&self.current_term, &mut temp_vec)?;
297 self.matvec_count += 1;
298
299 for (temp, &d_inv) in temp_vec.iter_mut().zip(&self.diagonal_inv) {
301 *temp *= d_inv;
302 }
303
304 for (curr, &temp) in self.current_term.iter_mut().zip(&temp_vec) {
306 *curr -= temp;
307 }
308
309 Ok(())
310 }
311
312 fn update_residual(&mut self, matrix: &dyn Matrix) -> Result<()> {
323 matrix.multiply_vector(&self.solution, &mut self.residual)?;
325 self.matvec_count += 1;
326
327 for (r, &b_val) in self.residual.iter_mut().zip(self.original_rhs.iter()) {
329 *r -= b_val;
330 }
331
332 self.residual_norm = utils::l2_norm(&self.residual);
333 Ok(())
334 }
335
336 fn estimate_error_bounds(&mut self) -> Result<()> {
338 if !self.series_converged || self.terms_computed == 0 {
339 return Ok(());
340 }
341
342 let mut matrix_norm_estimate = 0.0;
344 if self.terms_computed > 1 {
345 let term_ratio = utils::l2_norm(&self.current_term) / utils::l2_norm(&self.rhs);
346 matrix_norm_estimate = term_ratio.powf(1.0 / (self.terms_computed - 1) as Precision);
347 }
348
349 if matrix_norm_estimate < 1.0 {
350 let remaining_sum_bound = matrix_norm_estimate.powi(self.terms_computed as i32)
352 / (1.0 - matrix_norm_estimate);
353 let error_bound = remaining_sum_bound * utils::l2_norm(&self.rhs);
354
355 self.error_bounds = Some(ErrorBounds::upper_bound_only(
356 error_bound,
357 ErrorBoundMethod::NeumannTruncation,
358 ));
359 }
360
361 Ok(())
362 }
363}
364
365impl SolverState for NeumannState {
366 fn residual_norm(&self) -> Precision {
367 self.residual_norm
368 }
369
370 fn matvec_count(&self) -> usize {
371 self.matvec_count
372 }
373
374 fn error_bounds(&self) -> Option<ErrorBounds> {
375 self.error_bounds.clone()
376 }
377
378 fn memory_usage(&self) -> MemoryInfo {
379 self.memory_usage.clone()
380 }
381
382 fn reset(&mut self) {
383 self.solution.fill(0.0);
384 self.residual.fill(0.0);
385 self.residual_norm = Precision::INFINITY;
386 self.current_term = self.rhs.clone();
387 self.terms_computed = 0;
388 self.matvec_count = 0;
389 self.previous_solution = None;
390 self.series_converged = false;
391 self.error_bounds = None;
392 self.matrix_powers.clear();
393 }
394}
395
396impl SolverAlgorithm for NeumannSolver {
397 type State = NeumannState;
398
399 fn initialize(
400 &self,
401 matrix: &dyn Matrix,
402 b: &[Precision],
403 options: &SolverOptions,
404 ) -> Result<Self::State> {
405 NeumannState::new(matrix, b, options, self)
406 }
407
408 fn step(&self, state: &mut Self::State) -> Result<StepResult> {
409 state.previous_solution = Some(state.solution.clone());
411
412 return Err(SolverError::AlgorithmError {
417 algorithm: "neumann".to_string(),
418 message: "Matrix reference needed for iteration - design limitation".to_string(),
419 context: vec![],
420 });
421
422 }
436
437 fn is_converged(&self, state: &Self::State) -> bool {
438 let residual_converged = state.residual_norm <= state.tolerance;
440 let series_converged = state.series_converged;
441 let max_terms_reached = state.terms_computed >= state.max_terms;
442
443 residual_converged || (series_converged && !max_terms_reached)
445 }
446
447 fn extract_solution(&self, state: &Self::State) -> Vec<Precision> {
448 state.solution.clone()
449 }
450
451 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()> {
452 for &(index, delta) in delta_b {
454 if index >= state.dimension {
455 return Err(SolverError::IndexOutOfBounds {
456 index,
457 max_index: state.dimension - 1,
458 context: "rhs_update".to_string(),
459 });
460 }
461
462 let scaled_delta = delta * state.diagonal_inv[index];
464 state.rhs[index] += scaled_delta;
465
466 state.solution[index] += scaled_delta;
469 }
470
471 state.current_term = state.rhs.clone();
473 state.terms_computed = 0;
474 state.series_converged = false;
475
476 Ok(())
477 }
478
479 fn algorithm_name(&self) -> &'static str {
480 "neumann"
481 }
482
483 fn solve(
485 &self,
486 matrix: &dyn Matrix,
487 b: &[Precision],
488 options: &SolverOptions,
489 ) -> Result<crate::solver::SolverResult> {
490 let mut state = self.initialize(matrix, b, options)?;
491 let mut iterations = 0;
492
493 #[cfg(feature = "std")]
494 let start_time = std::time::Instant::now();
495
496 while !self.is_converged(&state) && iterations < options.max_iterations {
497 state.previous_solution = Some(state.solution.clone());
499
500 state.compute_next_term(matrix)?;
502
503 if iterations % 5 == 0 {
505 state.update_residual(matrix)?;
506 }
507
508 if options.compute_error_bounds && self.adaptive_truncation {
510 state.estimate_error_bounds()?;
511 }
512
513 iterations += 1;
514
515 if !state.residual_norm.is_finite() {
517 return Err(SolverError::NumericalInstability {
518 reason: "Non-finite residual norm".to_string(),
519 iteration: iterations,
520 residual_norm: state.residual_norm,
521 });
522 }
523
524 if state.series_converged {
526 break;
527 }
528 }
529
530 state.update_residual(matrix)?;
532
533 let converged = self.is_converged(&state);
534 let solution = self.extract_solution(&state);
535 let residual_norm = state.residual_norm();
536
537 if !converged && iterations >= options.max_iterations {
539 return Err(SolverError::ConvergenceFailure {
540 iterations,
541 residual_norm,
542 tolerance: options.tolerance,
543 algorithm: self.algorithm_name().to_string(),
544 });
545 }
546
547 let mut result = if converged {
548 crate::solver::SolverResult::success(solution, residual_norm, iterations)
549 } else {
550 crate::solver::SolverResult::failure(solution, residual_norm, iterations)
551 };
552
553 if options.collect_stats {
555 #[cfg(feature = "std")]
556 {
557 let total_time = start_time.elapsed().as_millis() as f64;
558 let mut stats = crate::types::SolverStats::new();
559 stats.total_time_ms = total_time;
560 stats.matvec_count = state.matvec_count();
561 result.stats = Some(stats);
562 }
563 }
564
565 if options.compute_error_bounds {
566 result.error_bounds = state.error_bounds();
567 }
568
569 Ok(result)
570 }
571}
572
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;

    /// Constructors store the requested settings and apply the documented presets.
    #[test]
    fn test_neumann_solver_creation() {
        let custom = NeumannSolver::new(16, 1e-8);
        assert_eq!(custom.max_terms, 16);
        assert_eq!(custom.series_tolerance, 1e-8);
        assert!(custom.adaptive_truncation);
        assert!(custom.cache_powers);

        let quick = NeumannSolver::fast();
        assert_eq!(quick.max_terms, 20);
        assert!(!quick.cache_powers);
    }

    /// A small diagonally dominant system whose exact solution is (1, 1).
    /// A solver error is tolerated and only printed.
    #[test]
    fn test_neumann_solver_simple_system() {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![5.0, 4.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let options = SolverOptions::default();

        let solved = match solver.solve(&matrix, &b, &options) {
            Ok(solved) => solved,
            Err(e) => {
                println!("Expected error: {:?}", e);
                return;
            }
        };

        assert!(solved.converged);
        let x = solved.solution;
        let near_one = |value: Precision| (value - 1.0).abs() < 0.1;
        assert!(near_one(x[0]));
        assert!(near_one(x[1]));
    }

    /// A matrix that is not diagonally dominant is rejected with the dedicated error.
    #[test]
    fn test_neumann_not_diagonally_dominant() {
        let entries = vec![(0, 0, 1.0), (0, 1, 3.0), (1, 0, 2.0), (1, 1, 1.0)];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![4.0, 3.0];

        let solver = NeumannSolver::new(20, 1e-8);
        let options = SolverOptions::default();

        let outcome = solver.solve(&matrix, &b, &options);

        assert!(outcome.is_err());
        match outcome {
            Err(SolverError::MatrixNotDiagonallyDominant { .. }) => {}
            _ => panic!("Expected MatrixNotDiagonallyDominant error"),
        }
    }

    /// Initialization scales the right-hand side by the inverse diagonal and
    /// starts with no terms computed.
    #[test]
    fn test_neumann_state_initialization() {
        let entries = vec![(0, 0, 2.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(entries, 2, 2).unwrap();
        let b = vec![4.0, 6.0];

        let options = SolverOptions::default();
        let state = NeumannSolver::default()
            .initialize(&matrix, &b, &options)
            .unwrap();

        assert_eq!(state.dimension, 2);
        assert_eq!(state.diagonal_inv, vec![0.5, 1.0 / 3.0]);
        assert_eq!(state.rhs, vec![2.0, 2.0]);
        assert_eq!(state.terms_computed, 0);
        assert!(!state.series_converged);
    }
}