1use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::numeric::Float;
12use scirs2_core::SparseElement;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
/// Configuration for building and applying an [`AMGPreconditioner`].
///
/// See the `Default` impl for the default values of every field.
#[derive(Debug, Clone)]
pub struct AMGOptions {
    /// Maximum number of levels in the hierarchy, counting the finest level
    /// (the input matrix). Default: 10. Should be at least 1.
    pub max_levels: usize,
    /// Strength-of-connection threshold: an off-diagonal entry `a_ij` of row
    /// `i` is strong when `|a_ij| >= theta * max_{k != i} |a_ik|`.
    /// Default: 0.25.
    pub theta: f64,
    /// A level with at most this many rows is not coarsened further.
    /// Default: 50.
    pub max_coarse_size: usize,
    /// Requested interpolation scheme. Default: `Classical`.
    /// NOTE(review): not read anywhere in `AMGPreconditioner` as written;
    /// confirm whether it is meant to select the interpolation.
    pub interpolation: InterpolationType,
    /// Smoother used for pre- and post-smoothing. Default: `GaussSeidel`.
    pub smoother: SmootherType,
    /// Smoothing sweeps before restricting the residual. Default: 1.
    pub pre_smooth_steps: usize,
    /// Smoothing sweeps after adding the coarse-grid correction. Default: 1.
    pub post_smooth_steps: usize,
    /// Recursion pattern of each multigrid cycle. Default: `V`.
    pub cycle_type: CycleType,
}
36
37impl Default for AMGOptions {
38 fn default() -> Self {
39 Self {
40 max_levels: 10,
41 theta: 0.25,
42 max_coarse_size: 50,
43 interpolation: InterpolationType::Classical,
44 smoother: SmootherType::GaussSeidel,
45 pre_smooth_steps: 1,
46 post_smooth_steps: 1,
47 cycle_type: CycleType::V,
48 }
49 }
50}
51
/// Interpolation scheme requested through `AMGOptions::interpolation`.
///
/// NOTE(review): the preconditioner in this module does not currently branch
/// on this value, so every variant yields the same prolongation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Classical interpolation (the default).
    Classical,
    /// Direct interpolation.
    Direct,
    /// Standard interpolation.
    Standard,
}
62
/// Relaxation method used for pre- and post-smoothing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmootherType {
    /// Forward Gauss–Seidel sweep, updating entries in place (the default).
    GaussSeidel,
    /// Jacobi sweep: every entry is updated from the previous iterate.
    Jacobi,
    /// Successive over-relaxation with a fixed relaxation factor of 1.2.
    SOR,
}
73
/// Recursion pattern of a multigrid cycle, selected through
/// `AMGOptions::cycle_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CycleType {
    /// V-cycle: one recursive visit of the next coarser level (the default).
    V,
    /// W-cycle: two recursive visits of the next coarser level.
    W,
    /// F-cycle.
    F,
}
84
/// Algebraic multigrid (AMG) preconditioner built from a CSR matrix.
///
/// Holds one operator per level plus the transfer operators between
/// consecutive levels; `AMGPreconditioner::apply` runs one multigrid cycle.
#[derive(Debug)]
pub struct AMGPreconditioner<T>
where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    /// Operator of each level. Index 0 is a copy of the input matrix; index
    /// `l + 1` is the Galerkin product `R_l * A_l * P_l` of level `l`.
    operators: Vec<CsrArray<T>>,
    /// `prolongations[l]` maps level-`l + 1` vectors to level `l`
    /// (shape `rows_l x rows_{l+1}`).
    prolongations: Vec<CsrArray<T>>,
    /// `restrictions[l]` is the transpose of `prolongations[l]`.
    restrictions: Vec<CsrArray<T>>,
    /// Options the hierarchy was built with; also consulted when cycling.
    options: AMGOptions,
    /// Number of levels; kept equal to `operators.len()`.
    num_levels: usize,
}
102
103impl<T> AMGPreconditioner<T>
104where
105 T: Float + SparseElement + Debug + Copy + 'static,
106{
107 pub fn new(matrix: &CsrArray<T>, options: AMGOptions) -> SparseResult<Self> {
134 let mut amg = AMGPreconditioner {
135 operators: vec![matrix.clone()],
136 prolongations: Vec::new(),
137 restrictions: Vec::new(),
138 options,
139 num_levels: 1,
140 };
141
142 amg.build_hierarchy()?;
144
145 Ok(amg)
146 }
147
148 fn build_hierarchy(&mut self) -> SparseResult<()> {
150 let mut level = 0;
151
152 while level < self.options.max_levels - 1 {
153 let currentmatrix = &self.operators[level];
154 let (rows, _) = currentmatrix.shape();
155
156 if rows <= self.options.max_coarse_size {
158 break;
159 }
160
161 let (coarsematrix, prolongation, restriction) = self.coarsen_level(currentmatrix)?;
163
164 let (coarse_rows, _) = coarsematrix.shape();
166 if coarse_rows >= rows {
167 break;
169 }
170
171 self.operators.push(coarsematrix);
172 self.prolongations.push(prolongation);
173 self.restrictions.push(restriction);
174 self.num_levels += 1;
175 level += 1;
176 }
177
178 Ok(())
179 }
180
181 fn coarsen_level(
183 &self,
184 matrix: &CsrArray<T>,
185 ) -> SparseResult<(CsrArray<T>, CsrArray<T>, CsrArray<T>)> {
186 let (n, _) = matrix.shape();
187
188 let strong_connections = self.detect_strong_connections(matrix)?;
190
191 let (c_points, f_points) = self.classical_cf_splitting(matrix, &strong_connections)?;
193
194 let mut fine_to_coarse = HashMap::new();
196 for (coarse_idx, &fine_idx) in c_points.iter().enumerate() {
197 fine_to_coarse.insert(fine_idx, coarse_idx);
198 }
199
200 let coarse_size = c_points.len();
201
202 let prolongation = self.build_prolongation(matrix, &fine_to_coarse, coarse_size)?;
204
205 let restriction_box = prolongation.transpose()?;
207 let restriction = restriction_box
208 .as_any()
209 .downcast_ref::<CsrArray<T>>()
210 .ok_or_else(|| {
211 SparseError::ValueError("Failed to downcast restriction to CsrArray".to_string())
212 })?
213 .clone();
214
215 let temp_box = restriction.dot(matrix)?;
217 let temp = temp_box
218 .as_any()
219 .downcast_ref::<CsrArray<T>>()
220 .ok_or_else(|| {
221 SparseError::ValueError("Failed to downcast temp to CsrArray".to_string())
222 })?;
223 let coarsematrix_box = temp.dot(&prolongation)?;
224 let coarsematrix = coarsematrix_box
225 .as_any()
226 .downcast_ref::<CsrArray<T>>()
227 .ok_or_else(|| {
228 SparseError::ValueError("Failed to downcast coarsematrix to CsrArray".to_string())
229 })?
230 .clone();
231
232 Ok((coarsematrix, prolongation, restriction))
233 }
234
235 fn detect_strong_connections(&self, matrix: &CsrArray<T>) -> SparseResult<Vec<Vec<usize>>> {
238 let (n, _) = matrix.shape();
239 let mut strong_connections = vec![Vec::new(); n];
240
241 #[allow(clippy::needless_range_loop)]
242 for i in 0..n {
243 let row_start = matrix.get_indptr()[i];
244 let row_end = matrix.get_indptr()[i + 1];
245
246 let mut max_off_diag = T::sparse_zero();
248 for j in row_start..row_end {
249 let col = matrix.get_indices()[j];
250 if col != i {
251 let val = matrix.get_data()[j].abs();
252 if val > max_off_diag {
253 max_off_diag = val;
254 }
255 }
256 }
257
258 let threshold = T::from(self.options.theta).unwrap() * max_off_diag;
260 for j in row_start..row_end {
261 let col = matrix.get_indices()[j];
262 if col != i {
263 let val = matrix.get_data()[j].abs();
264 if val >= threshold {
265 strong_connections[i].push(col);
266 }
267 }
268 }
269 }
270
271 Ok(strong_connections)
272 }
273
274 fn classical_cf_splitting(
276 &self,
277 matrix: &CsrArray<T>,
278 strong_connections: &[Vec<usize>],
279 ) -> SparseResult<(Vec<usize>, Vec<usize>)> {
280 let (n, _) = matrix.shape();
281
282 let mut influence = vec![0; n];
284 for i in 0..n {
285 influence[i] = strong_connections[i].len();
286 }
287
288 let mut point_type = vec![0; n];
290 let mut c_points = Vec::new();
291 let mut f_points = Vec::new();
292
293 let mut sorted_points: Vec<usize> = (0..n).collect();
295 sorted_points.sort_by(|&a, &b| influence[b].cmp(&influence[a]));
296
297 for &i in &sorted_points {
298 if point_type[i] != 0 {
299 continue; }
301
302 let mut needs_coarse = false;
304
305 for &j in &strong_connections[i] {
307 if point_type[j] == 2 {
308 let mut has_coarse_interp = false;
311 for &k in &strong_connections[j] {
312 if point_type[k] == 1 {
313 has_coarse_interp = true;
315 break;
316 }
317 }
318 if !has_coarse_interp {
319 needs_coarse = true;
320 break;
321 }
322 }
323 }
324
325 if needs_coarse || influence[i] > 2 {
326 point_type[i] = 1;
328 c_points.push(i);
329
330 for &j in &strong_connections[i] {
332 if point_type[j] == 0 {
333 point_type[j] = 2;
334 f_points.push(j);
335 }
336 }
337 }
338 }
339
340 #[allow(clippy::needless_range_loop)]
342 for i in 0..n {
343 if point_type[i] == 0 {
344 point_type[i] = 2;
345 f_points.push(i);
346 }
347 }
348
349 Ok((c_points, f_points))
350 }
351
352 fn build_prolongation(
354 &self,
355 matrix: &CsrArray<T>,
356 fine_to_coarse: &HashMap<usize, usize>,
357 coarse_size: usize,
358 ) -> SparseResult<CsrArray<T>> {
359 let (n, _) = matrix.shape();
360 let mut prolongation_data = Vec::new();
361 let mut prolongation_indices = Vec::new();
362 let mut prolongation_indptr = vec![0];
363
364 let strong_connections = self.detect_strong_connections(matrix)?;
366
367 #[allow(clippy::needless_range_loop)]
368 for i in 0..n {
369 if let Some(&coarse_idx) = fine_to_coarse.get(&i) {
370 prolongation_data.push(T::sparse_one());
372 prolongation_indices.push(coarse_idx);
373 } else {
374 let interp_weights = self.compute_interpolation_weights(
376 i,
377 matrix,
378 &strong_connections[i],
379 fine_to_coarse,
380 )?;
381
382 if interp_weights.is_empty() {
383 prolongation_data.push(T::sparse_one());
385 prolongation_indices.push(0);
386 } else {
387 for (coarse_idx, weight) in interp_weights {
389 prolongation_data.push(weight);
390 prolongation_indices.push(coarse_idx);
391 }
392 }
393 }
394 prolongation_indptr.push(prolongation_data.len());
395 }
396
397 CsrArray::new(
398 prolongation_data.into(),
399 prolongation_indptr.into(),
400 prolongation_indices.into(),
401 (n, coarse_size),
402 )
403 }
404
405 fn compute_interpolation_weights(
407 &self,
408 fine_point: usize,
409 matrix: &CsrArray<T>,
410 strong_neighbors: &[usize],
411 fine_to_coarse: &HashMap<usize, usize>,
412 ) -> SparseResult<Vec<(usize, T)>> {
413 let mut weights = Vec::new();
414
415 let mut coarse_neighbors = Vec::new();
417 let mut coarse_weights = Vec::new();
418
419 for &neighbor in strong_neighbors {
420 if let Some(&coarse_idx) = fine_to_coarse.get(&neighbor) {
421 coarse_neighbors.push(neighbor);
422 coarse_weights.push(coarse_idx);
423 }
424 }
425
426 if coarse_neighbors.is_empty() {
427 return Ok(weights);
428 }
429
430 let mut a_ii = T::sparse_zero();
432 let row_start = matrix.get_indptr()[fine_point];
433 let row_end = matrix.get_indptr()[fine_point + 1];
434
435 for j in row_start..row_end {
436 let col = matrix.get_indices()[j];
437 if col == fine_point {
438 a_ii = matrix.get_data()[j];
439 break;
440 }
441 }
442
443 if SparseElement::is_zero(&a_ii) {
444 return Ok(weights);
445 }
446
447 let mut total_weight = T::sparse_zero();
450 let mut temp_weights = Vec::new();
451
452 for &coarse_neighbor in &coarse_neighbors {
453 let mut a_ij = T::sparse_zero();
454 for j in row_start..row_end {
455 let col = matrix.get_indices()[j];
456 if col == coarse_neighbor {
457 a_ij = matrix.get_data()[j];
458 break;
459 }
460 }
461
462 if !SparseElement::is_zero(&a_ij) {
463 let weight = -a_ij / a_ii;
464 temp_weights.push(weight);
465 total_weight = total_weight + weight;
466 } else {
467 temp_weights.push(T::sparse_zero());
468 }
469 }
470
471 if !SparseElement::is_zero(&total_weight) {
473 for (i, &coarse_idx) in coarse_weights.iter().enumerate() {
474 let normalized_weight = temp_weights[i] / total_weight;
475 if !SparseElement::is_zero(&normalized_weight) {
476 weights.push((coarse_idx, normalized_weight));
477 }
478 }
479 }
480
481 Ok(weights)
482 }
483
484 pub fn apply(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
496 let (n, _) = self.operators[0].shape();
497 if b.len() != n {
498 return Err(SparseError::DimensionMismatch {
499 expected: n,
500 found: b.len(),
501 });
502 }
503
504 let mut x = Array1::zeros(n);
505 self.mg_cycle(&mut x, b, 0)?;
506 Ok(x)
507 }
508
509 fn mg_cycle(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
511 if level == self.num_levels - 1 {
512 self.coarse_solve(x, b, level)?;
514 return Ok(());
515 }
516
517 let matrix = &self.operators[level];
518
519 for _ in 0..self.options.pre_smooth_steps {
521 self.smooth(x, b, matrix)?;
522 }
523
524 let ax = matrix_vector_multiply(matrix, &x.view())?;
526 let residual = b - &ax;
527
528 let restriction = &self.restrictions[level];
530 let coarse_residual = matrix_vector_multiply(restriction, &residual.view())?;
531
532 let coarse_size = coarse_residual.len();
534 let mut coarse_correction = Array1::zeros(coarse_size);
535
536 match self.options.cycle_type {
537 CycleType::V => {
538 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
539 }
540 CycleType::W => {
541 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
543 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
544 }
545 CycleType::F => {
546 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
548 }
549 }
550
551 let prolongation = &self.prolongations[level];
553 let fine_correction = matrix_vector_multiply(prolongation, &coarse_correction.view())?;
554
555 for i in 0..x.len() {
557 x[i] = x[i] + fine_correction[i];
558 }
559
560 for _ in 0..self.options.post_smooth_steps {
562 self.smooth(x, b, matrix)?;
563 }
564
565 Ok(())
566 }
567
568 fn smooth(
570 &self,
571 x: &mut Array1<T>,
572 b: &ArrayView1<T>,
573 matrix: &CsrArray<T>,
574 ) -> SparseResult<()> {
575 match self.options.smoother {
576 SmootherType::GaussSeidel => self.gauss_seidel_smooth(x, b, matrix),
577 SmootherType::Jacobi => self.jacobi_smooth(x, b, matrix),
578 SmootherType::SOR => self.sor_smooth(x, b, matrix, T::from(1.2).unwrap()),
579 }
580 }
581
582 fn gauss_seidel_smooth(
584 &self,
585 x: &mut Array1<T>,
586 b: &ArrayView1<T>,
587 matrix: &CsrArray<T>,
588 ) -> SparseResult<()> {
589 let n = x.len();
590
591 for i in 0..n {
592 let row_start = matrix.get_indptr()[i];
593 let row_end = matrix.get_indptr()[i + 1];
594
595 let mut sum = T::sparse_zero();
596 let mut diag_val = T::sparse_zero();
597
598 for j in row_start..row_end {
599 let col = matrix.get_indices()[j];
600 let val = matrix.get_data()[j];
601
602 if col == i {
603 diag_val = val;
604 } else {
605 sum = sum + val * x[col];
606 }
607 }
608
609 if !SparseElement::is_zero(&diag_val) {
610 x[i] = (b[i] - sum) / diag_val;
611 }
612 }
613
614 Ok(())
615 }
616
617 fn jacobi_smooth(
619 &self,
620 x: &mut Array1<T>,
621 b: &ArrayView1<T>,
622 matrix: &CsrArray<T>,
623 ) -> SparseResult<()> {
624 let n = x.len();
625 let mut x_new = x.clone();
626
627 for i in 0..n {
628 let row_start = matrix.get_indptr()[i];
629 let row_end = matrix.get_indptr()[i + 1];
630
631 let mut sum = T::sparse_zero();
632 let mut diag_val = T::sparse_zero();
633
634 for j in row_start..row_end {
635 let col = matrix.get_indices()[j];
636 let val = matrix.get_data()[j];
637
638 if col == i {
639 diag_val = val;
640 } else {
641 sum = sum + val * x[col];
642 }
643 }
644
645 if !SparseElement::is_zero(&diag_val) {
646 x_new[i] = (b[i] - sum) / diag_val;
647 }
648 }
649
650 *x = x_new;
651 Ok(())
652 }
653
654 fn sor_smooth(
656 &self,
657 x: &mut Array1<T>,
658 b: &ArrayView1<T>,
659 matrix: &CsrArray<T>,
660 omega: T,
661 ) -> SparseResult<()> {
662 let n = x.len();
663
664 for i in 0..n {
665 let row_start = matrix.get_indptr()[i];
666 let row_end = matrix.get_indptr()[i + 1];
667
668 let mut sum = T::sparse_zero();
669 let mut diag_val = T::sparse_zero();
670
671 for j in row_start..row_end {
672 let col = matrix.get_indices()[j];
673 let val = matrix.get_data()[j];
674
675 if col == i {
676 diag_val = val;
677 } else {
678 sum = sum + val * x[col];
679 }
680 }
681
682 if !SparseElement::is_zero(&diag_val) {
683 let x_gs = (b[i] - sum) / diag_val;
684 x[i] = (T::sparse_one() - omega) * x[i] + omega * x_gs;
685 }
686 }
687
688 Ok(())
689 }
690
691 fn coarse_solve(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
693 let matrix = &self.operators[level];
695
696 for _ in 0..10 {
697 self.gauss_seidel_smooth(x, b, matrix)?;
698 }
699
700 Ok(())
701 }
702
703 pub fn num_levels(&self) -> usize {
705 self.num_levels
706 }
707
708 pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
710 if level < self.num_levels {
711 Some(self.operators[level].shape())
712 } else {
713 None
714 }
715 }
716}
717
718#[allow(dead_code)]
720fn matrix_vector_multiply<T>(matrix: &CsrArray<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
721where
722 T: Float + SparseElement + Debug + Copy + 'static,
723{
724 let (rows, cols) = matrix.shape();
725 if x.len() != cols {
726 return Err(SparseError::DimensionMismatch {
727 expected: cols,
728 found: x.len(),
729 });
730 }
731
732 let mut result = Array1::zeros(rows);
733
734 for i in 0..rows {
735 for j in matrix.get_indptr()[i]..matrix.get_indptr()[i + 1] {
736 let col = matrix.get_indices()[j];
737 let val = matrix.get_data()[j];
738 result[i] = result[i] + val * x[col];
739 }
740 }
741
742 Ok(result)
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    // Shared 3x3 fixture (row 1 has no entry in column 2):
    //
    //   [ 2 -1  0 ]
    //   [-1  2  0 ]
    //   [ 0 -1  2 ]
    fn small_matrix() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    // Asserts element-wise |actual - expected| <= tol.
    fn assert_close(actual: &Array1<f64>, expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() <= tol,
                "entry {}: got {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn test_amg_preconditioner_creation() {
        let matrix = small_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        assert!(amg.num_levels() >= 1);
        assert_eq!(amg.level_size(0), Some((3, 3)));
        assert_eq!(amg.level_size(amg.num_levels()), None);
    }

    #[test]
    fn test_amg_apply() {
        // diag(2, 3, 4) x = (2, 3, 4) has the exact solution (1, 1, 1).
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let b = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let x = amg.apply(&b.view()).unwrap();

        assert_close(&x, &[1.0, 1.0, 1.0], 1e-10);
    }

    #[test]
    fn test_amg_apply_rejects_wrong_length() {
        let matrix = small_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let b = Array1::from_vec(vec![1.0, 2.0]);
        assert!(amg.apply(&b.view()).is_err());
    }

    #[test]
    fn test_amg_options() {
        let options = AMGOptions {
            max_levels: 5,
            theta: 0.5,
            smoother: SmootherType::Jacobi,
            cycle_type: CycleType::W,
            ..Default::default()
        };

        assert_eq!(options.max_levels, 5);
        assert_eq!(options.theta, 0.5);
        assert!(matches!(options.smoother, SmootherType::Jacobi));
        assert!(matches!(options.cycle_type, CycleType::W));

        let defaults = AMGOptions::default();
        assert_eq!(defaults.max_levels, 10);
        assert_eq!(defaults.theta, 0.25);
        assert_eq!(defaults.max_coarse_size, 50);
        assert_eq!(defaults.pre_smooth_steps, 1);
        assert_eq!(defaults.post_smooth_steps, 1);
        assert!(matches!(defaults.interpolation, InterpolationType::Classical));
        assert!(matches!(defaults.smoother, SmootherType::GaussSeidel));
        assert!(matches!(defaults.cycle_type, CycleType::V));
    }

    #[test]
    fn test_gauss_seidel_smoother() {
        let matrix = small_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let mut x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        amg.gauss_seidel_smooth(&mut x, &b.view(), &matrix).unwrap();

        // In-place sweep: x0 = 1/2, x1 = (1 + x0)/2, x2 = (1 + x1)/2.
        assert_close(&x, &[0.5, 0.75, 0.875], 1e-12);
    }

    #[test]
    fn test_jacobi_smoother() {
        let matrix = small_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let mut x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        amg.jacobi_smooth(&mut x, &b.view(), &matrix).unwrap();

        // Every row uses the zero initial guess: x_i = b_i / a_ii.
        assert_close(&x, &[0.5, 0.5, 0.5], 1e-12);
    }

    #[test]
    fn test_sor_smoother() {
        let matrix = small_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let mut x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        amg.sor_smooth(&mut x, &b.view(), &matrix, 1.2).unwrap();

        // x0 = 1.2 * 0.5, x1 = 1.2 * (1 + x0)/2, x2 = 1.2 * (1 + x1)/2.
        assert_close(&x, &[0.6, 0.96, 1.176], 1e-12);
    }

    #[test]
    fn test_enhanced_amg_coarsening() {
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4];
        let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 3, 4, 0];
        let data = vec![
            4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, 4.0, -1.0,
        ];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (5, 5), false).unwrap();

        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };

        let amg = AMGPreconditioner::new(&matrix, options).unwrap();
        assert!(amg.num_levels() >= 1);

        let b = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let x = amg.apply(&b.view()).unwrap();
        assert_eq!(x.len(), 5);

        // The matrix is strictly diagonally dominant, so one cycle on this
        // small system must already reduce the residual substantially.
        let ax = matrix_vector_multiply(&matrix, &x.view()).unwrap();
        let residual = &b - &ax;
        assert!(residual.iter().all(|r: &f64| r.abs() < 1e-3));
    }

    #[test]
    fn test_strong_connection_detection() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![4.0, -2.0, -2.0, 4.0, -2.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).unwrap();

        let strong_connections = amg.detect_strong_connections(&matrix).unwrap();

        // Every row has exactly one off-diagonal entry, which is therefore
        // the row maximum and strong.
        assert_eq!(strong_connections, vec![vec![1], vec![0], vec![1]]);
    }
}