1use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::numeric::Float;
12use scirs2_core::SparseElement;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
/// Tunable settings used when building and applying an [`AMGPreconditioner`].
#[derive(Debug, Clone)]
pub struct AMGOptions {
    /// Upper bound on the number of levels in the hierarchy, counting the finest one.
    pub max_levels: usize,
    /// Strength-of-connection threshold: an off-diagonal entry counts as a strong
    /// connection when its magnitude is at least `theta` times the largest
    /// off-diagonal magnitude in the same row.
    pub theta: f64,
    /// Coarsening stops once a level has at most this many rows.
    pub max_coarse_size: usize,
    /// Requested interpolation scheme.
    // NOTE(review): no code visible in this file reads this field — confirm whether it
    // is consumed elsewhere or still unimplemented.
    pub interpolation: InterpolationType,
    /// Relaxation method used for pre- and post-smoothing on every non-coarsest level.
    pub smoother: SmootherType,
    /// Number of smoother sweeps performed before restricting the residual.
    pub pre_smooth_steps: usize,
    /// Number of smoother sweeps performed after the coarse-grid correction.
    pub post_smooth_steps: usize,
    /// Shape of the multigrid cycle executed by `apply`.
    pub cycle_type: CycleType,
}
36
37impl Default for AMGOptions {
38 fn default() -> Self {
39 Self {
40 max_levels: 10,
41 theta: 0.25,
42 max_coarse_size: 50,
43 interpolation: InterpolationType::Classical,
44 smoother: SmootherType::GaussSeidel,
45 pre_smooth_steps: 1,
46 post_smooth_steps: 1,
47 cycle_type: CycleType::V,
48 }
49 }
50}
51
/// Interpolation scheme requested for building prolongation operators.
///
/// Derives `PartialEq`/`Eq` so callers can compare variants directly instead of
/// resorting to `matches!`.
// NOTE(review): the variant names follow the usual AMG terminology, but no code visible
// in this file branches on this enum — confirm how (or whether) each variant is honoured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolationType {
    /// Classical interpolation (the default in `AMGOptions::default`).
    Classical,
    /// Direct interpolation.
    Direct,
    /// Standard interpolation.
    Standard,
}
62
/// Relaxation method applied during the smoothing phases of a multigrid cycle.
///
/// Derives `PartialEq`/`Eq` so callers can compare variants directly instead of
/// resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmootherType {
    /// Gauss-Seidel sweep: each unknown is updated in place using the newest values.
    GaussSeidel,
    /// Jacobi sweep: all unknowns are updated from the previous iterate.
    Jacobi,
    /// Successive over-relaxation; the preconditioner uses a fixed relaxation factor of 1.2.
    SOR,
}
73
/// Shape of the multigrid cycle, i.e. how often each coarser level is visited.
///
/// Derives `PartialEq`/`Eq` so callers can compare variants directly instead of
/// resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CycleType {
    /// V-cycle: one recursive visit to the next coarser level.
    V,
    /// W-cycle: two consecutive recursive visits to the next coarser level.
    W,
    /// F-cycle.
    // NOTE(review): the cycle implementation in this file performs a single recursive
    // visit for `F`, exactly like `V` — confirm whether a true F-cycle is intended.
    F,
}
84
/// Algebraic multigrid preconditioner holding the full level hierarchy.
///
/// Level 0 is the finest level (the matrix handed to `new`); higher indices are
/// progressively coarser.
#[derive(Debug)]
pub struct AMGPreconditioner<T>
where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    /// System matrix of every level; `operators[0]` is a clone of the input matrix.
    operators: Vec<CsrArray<T>>,
    /// `prolongations[l]` maps a vector on level `l + 1` to level `l`.
    prolongations: Vec<CsrArray<T>>,
    /// `restrictions[l]` maps a vector on level `l` to level `l + 1`; built as the
    /// transpose of `prolongations[l]`.
    restrictions: Vec<CsrArray<T>>,
    /// Settings supplied at construction time.
    options: AMGOptions,
    /// Number of levels in the hierarchy; starts at 1 and grows with every coarse
    /// operator that is added.
    num_levels: usize,
}
102
103impl<T> AMGPreconditioner<T>
104where
105 T: Float + SparseElement + Debug + Copy + 'static,
106{
107 pub fn new(matrix: &CsrArray<T>, options: AMGOptions) -> SparseResult<Self> {
134 let mut amg = AMGPreconditioner {
135 operators: vec![matrix.clone()],
136 prolongations: Vec::new(),
137 restrictions: Vec::new(),
138 options,
139 num_levels: 1,
140 };
141
142 amg.build_hierarchy()?;
144
145 Ok(amg)
146 }
147
148 fn build_hierarchy(&mut self) -> SparseResult<()> {
150 let mut level = 0;
151
152 while level < self.options.max_levels - 1 {
153 let currentmatrix = &self.operators[level];
154 let (rows, _) = currentmatrix.shape();
155
156 if rows <= self.options.max_coarse_size {
158 break;
159 }
160
161 let (coarsematrix, prolongation, restriction) = self.coarsen_level(currentmatrix)?;
163
164 let (coarse_rows, _) = coarsematrix.shape();
166 if coarse_rows >= rows {
167 break;
169 }
170
171 self.operators.push(coarsematrix);
172 self.prolongations.push(prolongation);
173 self.restrictions.push(restriction);
174 self.num_levels += 1;
175 level += 1;
176 }
177
178 Ok(())
179 }
180
181 fn coarsen_level(
183 &self,
184 matrix: &CsrArray<T>,
185 ) -> SparseResult<(CsrArray<T>, CsrArray<T>, CsrArray<T>)> {
186 let (n, _) = matrix.shape();
187
188 let strong_connections = self.detect_strong_connections(matrix)?;
190
191 let (c_points, f_points) = self.classical_cf_splitting(matrix, &strong_connections)?;
193
194 let mut fine_to_coarse = HashMap::new();
196 for (coarse_idx, &fine_idx) in c_points.iter().enumerate() {
197 fine_to_coarse.insert(fine_idx, coarse_idx);
198 }
199
200 let coarse_size = c_points.len();
201
202 let prolongation = self.build_prolongation(matrix, &fine_to_coarse, coarse_size)?;
204
205 let restriction_box = prolongation.transpose()?;
207 let restriction = restriction_box
208 .as_any()
209 .downcast_ref::<CsrArray<T>>()
210 .ok_or_else(|| {
211 SparseError::ValueError("Failed to downcast restriction to CsrArray".to_string())
212 })?
213 .clone();
214
215 let temp_box = restriction.dot(matrix)?;
217 let temp = temp_box
218 .as_any()
219 .downcast_ref::<CsrArray<T>>()
220 .ok_or_else(|| {
221 SparseError::ValueError("Failed to downcast temp to CsrArray".to_string())
222 })?;
223 let coarsematrix_box = temp.dot(&prolongation)?;
224 let coarsematrix = coarsematrix_box
225 .as_any()
226 .downcast_ref::<CsrArray<T>>()
227 .ok_or_else(|| {
228 SparseError::ValueError("Failed to downcast coarsematrix to CsrArray".to_string())
229 })?
230 .clone();
231
232 Ok((coarsematrix, prolongation, restriction))
233 }
234
235 fn detect_strong_connections(&self, matrix: &CsrArray<T>) -> SparseResult<Vec<Vec<usize>>> {
238 let (n, _) = matrix.shape();
239 let mut strong_connections = vec![Vec::new(); n];
240
241 #[allow(clippy::needless_range_loop)]
242 for i in 0..n {
243 let row_start = matrix.get_indptr()[i];
244 let row_end = matrix.get_indptr()[i + 1];
245
246 let mut max_off_diag = T::sparse_zero();
248 for j in row_start..row_end {
249 let col = matrix.get_indices()[j];
250 if col != i {
251 let val = matrix.get_data()[j].abs();
252 if val > max_off_diag {
253 max_off_diag = val;
254 }
255 }
256 }
257
258 let threshold = T::from(self.options.theta).expect("Operation failed") * max_off_diag;
260 for j in row_start..row_end {
261 let col = matrix.get_indices()[j];
262 if col != i {
263 let val = matrix.get_data()[j].abs();
264 if val >= threshold {
265 strong_connections[i].push(col);
266 }
267 }
268 }
269 }
270
271 Ok(strong_connections)
272 }
273
274 fn classical_cf_splitting(
276 &self,
277 matrix: &CsrArray<T>,
278 strong_connections: &[Vec<usize>],
279 ) -> SparseResult<(Vec<usize>, Vec<usize>)> {
280 let (n, _) = matrix.shape();
281
282 let mut influence = vec![0; n];
284 for i in 0..n {
285 influence[i] = strong_connections[i].len();
286 }
287
288 let mut point_type = vec![0; n];
290 let mut c_points = Vec::new();
291 let mut f_points = Vec::new();
292
293 let mut sorted_points: Vec<usize> = (0..n).collect();
295 sorted_points.sort_by(|&a, &b| influence[b].cmp(&influence[a]));
296
297 for &i in &sorted_points {
298 if point_type[i] != 0 {
299 continue; }
301
302 let mut needs_coarse = false;
304
305 for &j in &strong_connections[i] {
307 if point_type[j] == 2 {
308 let mut has_coarse_interp = false;
311 for &k in &strong_connections[j] {
312 if point_type[k] == 1 {
313 has_coarse_interp = true;
315 break;
316 }
317 }
318 if !has_coarse_interp {
319 needs_coarse = true;
320 break;
321 }
322 }
323 }
324
325 if needs_coarse || influence[i] > 2 {
326 point_type[i] = 1;
328 c_points.push(i);
329
330 for &j in &strong_connections[i] {
332 if point_type[j] == 0 {
333 point_type[j] = 2;
334 f_points.push(j);
335 }
336 }
337 }
338 }
339
340 #[allow(clippy::needless_range_loop)]
342 for i in 0..n {
343 if point_type[i] == 0 {
344 point_type[i] = 2;
345 f_points.push(i);
346 }
347 }
348
349 Ok((c_points, f_points))
350 }
351
352 fn build_prolongation(
354 &self,
355 matrix: &CsrArray<T>,
356 fine_to_coarse: &HashMap<usize, usize>,
357 coarse_size: usize,
358 ) -> SparseResult<CsrArray<T>> {
359 let (n, _) = matrix.shape();
360 let mut prolongation_data = Vec::new();
361 let mut prolongation_indices = Vec::new();
362 let mut prolongation_indptr = vec![0];
363
364 let strong_connections = self.detect_strong_connections(matrix)?;
366
367 #[allow(clippy::needless_range_loop)]
368 for i in 0..n {
369 if let Some(&coarse_idx) = fine_to_coarse.get(&i) {
370 prolongation_data.push(T::sparse_one());
372 prolongation_indices.push(coarse_idx);
373 } else {
374 let interp_weights = self.compute_interpolation_weights(
376 i,
377 matrix,
378 &strong_connections[i],
379 fine_to_coarse,
380 )?;
381
382 if interp_weights.is_empty() {
383 prolongation_data.push(T::sparse_one());
385 prolongation_indices.push(0);
386 } else {
387 for (coarse_idx, weight) in interp_weights {
389 prolongation_data.push(weight);
390 prolongation_indices.push(coarse_idx);
391 }
392 }
393 }
394 prolongation_indptr.push(prolongation_data.len());
395 }
396
397 CsrArray::new(
398 prolongation_data.into(),
399 prolongation_indptr.into(),
400 prolongation_indices.into(),
401 (n, coarse_size),
402 )
403 }
404
405 fn compute_interpolation_weights(
407 &self,
408 fine_point: usize,
409 matrix: &CsrArray<T>,
410 strong_neighbors: &[usize],
411 fine_to_coarse: &HashMap<usize, usize>,
412 ) -> SparseResult<Vec<(usize, T)>> {
413 let mut weights = Vec::new();
414
415 let mut coarse_neighbors = Vec::new();
417 let mut coarse_weights = Vec::new();
418
419 for &neighbor in strong_neighbors {
420 if let Some(&coarse_idx) = fine_to_coarse.get(&neighbor) {
421 coarse_neighbors.push(neighbor);
422 coarse_weights.push(coarse_idx);
423 }
424 }
425
426 if coarse_neighbors.is_empty() {
427 return Ok(weights);
428 }
429
430 let mut a_ii = T::sparse_zero();
432 let row_start = matrix.get_indptr()[fine_point];
433 let row_end = matrix.get_indptr()[fine_point + 1];
434
435 for j in row_start..row_end {
436 let col = matrix.get_indices()[j];
437 if col == fine_point {
438 a_ii = matrix.get_data()[j];
439 break;
440 }
441 }
442
443 if SparseElement::is_zero(&a_ii) {
444 return Ok(weights);
445 }
446
447 let mut total_weight = T::sparse_zero();
450 let mut temp_weights = Vec::new();
451
452 for &coarse_neighbor in &coarse_neighbors {
453 let mut a_ij = T::sparse_zero();
454 for j in row_start..row_end {
455 let col = matrix.get_indices()[j];
456 if col == coarse_neighbor {
457 a_ij = matrix.get_data()[j];
458 break;
459 }
460 }
461
462 if !SparseElement::is_zero(&a_ij) {
463 let weight = -a_ij / a_ii;
464 temp_weights.push(weight);
465 total_weight = total_weight + weight;
466 } else {
467 temp_weights.push(T::sparse_zero());
468 }
469 }
470
471 if !SparseElement::is_zero(&total_weight) {
473 for (i, &coarse_idx) in coarse_weights.iter().enumerate() {
474 let normalized_weight = temp_weights[i] / total_weight;
475 if !SparseElement::is_zero(&normalized_weight) {
476 weights.push((coarse_idx, normalized_weight));
477 }
478 }
479 }
480
481 Ok(weights)
482 }
483
484 pub fn apply(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
496 let (n, _) = self.operators[0].shape();
497 if b.len() != n {
498 return Err(SparseError::DimensionMismatch {
499 expected: n,
500 found: b.len(),
501 });
502 }
503
504 let mut x = Array1::zeros(n);
505 self.mg_cycle(&mut x, b, 0)?;
506 Ok(x)
507 }
508
509 fn mg_cycle(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
511 if level == self.num_levels - 1 {
512 self.coarse_solve(x, b, level)?;
514 return Ok(());
515 }
516
517 let matrix = &self.operators[level];
518
519 for _ in 0..self.options.pre_smooth_steps {
521 self.smooth(x, b, matrix)?;
522 }
523
524 let ax = matrix_vector_multiply(matrix, &x.view())?;
526 let residual = b - &ax;
527
528 let restriction = &self.restrictions[level];
530 let coarse_residual = matrix_vector_multiply(restriction, &residual.view())?;
531
532 let coarse_size = coarse_residual.len();
534 let mut coarse_correction = Array1::zeros(coarse_size);
535
536 match self.options.cycle_type {
537 CycleType::V => {
538 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
539 }
540 CycleType::W => {
541 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
543 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
544 }
545 CycleType::F => {
546 self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
548 }
549 }
550
551 let prolongation = &self.prolongations[level];
553 let fine_correction = matrix_vector_multiply(prolongation, &coarse_correction.view())?;
554
555 for i in 0..x.len() {
557 x[i] = x[i] + fine_correction[i];
558 }
559
560 for _ in 0..self.options.post_smooth_steps {
562 self.smooth(x, b, matrix)?;
563 }
564
565 Ok(())
566 }
567
568 fn smooth(
570 &self,
571 x: &mut Array1<T>,
572 b: &ArrayView1<T>,
573 matrix: &CsrArray<T>,
574 ) -> SparseResult<()> {
575 match self.options.smoother {
576 SmootherType::GaussSeidel => self.gauss_seidel_smooth(x, b, matrix),
577 SmootherType::Jacobi => self.jacobi_smooth(x, b, matrix),
578 SmootherType::SOR => {
579 self.sor_smooth(x, b, matrix, T::from(1.2).expect("Operation failed"))
580 }
581 }
582 }
583
584 fn gauss_seidel_smooth(
586 &self,
587 x: &mut Array1<T>,
588 b: &ArrayView1<T>,
589 matrix: &CsrArray<T>,
590 ) -> SparseResult<()> {
591 let n = x.len();
592
593 for i in 0..n {
594 let row_start = matrix.get_indptr()[i];
595 let row_end = matrix.get_indptr()[i + 1];
596
597 let mut sum = T::sparse_zero();
598 let mut diag_val = T::sparse_zero();
599
600 for j in row_start..row_end {
601 let col = matrix.get_indices()[j];
602 let val = matrix.get_data()[j];
603
604 if col == i {
605 diag_val = val;
606 } else {
607 sum = sum + val * x[col];
608 }
609 }
610
611 if !SparseElement::is_zero(&diag_val) {
612 x[i] = (b[i] - sum) / diag_val;
613 }
614 }
615
616 Ok(())
617 }
618
619 fn jacobi_smooth(
621 &self,
622 x: &mut Array1<T>,
623 b: &ArrayView1<T>,
624 matrix: &CsrArray<T>,
625 ) -> SparseResult<()> {
626 let n = x.len();
627 let mut x_new = x.clone();
628
629 for i in 0..n {
630 let row_start = matrix.get_indptr()[i];
631 let row_end = matrix.get_indptr()[i + 1];
632
633 let mut sum = T::sparse_zero();
634 let mut diag_val = T::sparse_zero();
635
636 for j in row_start..row_end {
637 let col = matrix.get_indices()[j];
638 let val = matrix.get_data()[j];
639
640 if col == i {
641 diag_val = val;
642 } else {
643 sum = sum + val * x[col];
644 }
645 }
646
647 if !SparseElement::is_zero(&diag_val) {
648 x_new[i] = (b[i] - sum) / diag_val;
649 }
650 }
651
652 *x = x_new;
653 Ok(())
654 }
655
656 fn sor_smooth(
658 &self,
659 x: &mut Array1<T>,
660 b: &ArrayView1<T>,
661 matrix: &CsrArray<T>,
662 omega: T,
663 ) -> SparseResult<()> {
664 let n = x.len();
665
666 for i in 0..n {
667 let row_start = matrix.get_indptr()[i];
668 let row_end = matrix.get_indptr()[i + 1];
669
670 let mut sum = T::sparse_zero();
671 let mut diag_val = T::sparse_zero();
672
673 for j in row_start..row_end {
674 let col = matrix.get_indices()[j];
675 let val = matrix.get_data()[j];
676
677 if col == i {
678 diag_val = val;
679 } else {
680 sum = sum + val * x[col];
681 }
682 }
683
684 if !SparseElement::is_zero(&diag_val) {
685 let x_gs = (b[i] - sum) / diag_val;
686 x[i] = (T::sparse_one() - omega) * x[i] + omega * x_gs;
687 }
688 }
689
690 Ok(())
691 }
692
693 fn coarse_solve(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
695 let matrix = &self.operators[level];
697
698 for _ in 0..10 {
699 self.gauss_seidel_smooth(x, b, matrix)?;
700 }
701
702 Ok(())
703 }
704
    /// Returns the number of levels in the hierarchy, including the finest one.
    pub fn num_levels(&self) -> usize {
        self.num_levels
    }
709
710 pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
712 if level < self.num_levels {
713 Some(self.operators[level].shape())
714 } else {
715 None
716 }
717 }
718}
719
720#[allow(dead_code)]
722fn matrix_vector_multiply<T>(matrix: &CsrArray<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
723where
724 T: Float + SparseElement + Debug + Copy + 'static,
725{
726 let (rows, cols) = matrix.shape();
727 if x.len() != cols {
728 return Err(SparseError::DimensionMismatch {
729 expected: cols,
730 found: x.len(),
731 });
732 }
733
734 let mut result = Array1::zeros(rows);
735
736 for i in 0..rows {
737 for j in matrix.get_indptr()[i]..matrix.get_indptr()[i + 1] {
738 let col = matrix.get_indices()[j];
739 let val = matrix.get_data()[j];
740 result[i] = result[i] + val * x[col];
741 }
742 }
743
744 Ok(result)
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Builds a square `n x n` CSR matrix from unsorted triplets.
    fn square_csr(rows: Vec<usize>, cols: Vec<usize>, data: Vec<f64>, n: usize) -> CsrArray<f64> {
        CsrArray::from_triplets(&rows, &cols, &data, (n, n), false).expect("Operation failed")
    }

    /// Builds a preconditioner for `matrix` with the given options.
    fn amg_with(matrix: &CsrArray<f64>, options: AMGOptions) -> AMGPreconditioner<f64> {
        AMGPreconditioner::new(matrix, options).expect("Operation failed")
    }

    /// 3x3 matrix with diagonal `diag` and the off-diagonal pattern shared by several tests.
    fn small_coupled_matrix(diag: f64, off: f64) -> CsrArray<f64> {
        square_csr(
            vec![0, 0, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 1, 2],
            vec![diag, off, off, diag, off, diag],
            3,
        )
    }

    #[test]
    fn test_amg_preconditioner_creation() {
        let matrix = small_coupled_matrix(2.0, -1.0);
        let amg = amg_with(&matrix, AMGOptions::default());

        assert!(amg.num_levels() >= 1);
        assert_eq!(amg.level_size(0), Some((3, 3)));
    }

    #[test]
    fn test_amg_apply() {
        // Diagonal system whose exact solution is all ones.
        let matrix = square_csr(vec![0, 1, 2], vec![0, 1, 2], vec![2.0, 3.0, 4.0], 3);
        let amg = amg_with(&matrix, AMGOptions::default());

        let rhs = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let solution = amg.apply(&rhs.view()).expect("Operation failed");

        for i in 0..3 {
            assert!(solution[i] > 0.5 && solution[i] < 1.5);
        }
    }

    #[test]
    fn test_amg_options() {
        let options = AMGOptions {
            cycle_type: CycleType::W,
            smoother: SmootherType::Jacobi,
            theta: 0.5,
            max_levels: 5,
            ..Default::default()
        };

        assert_eq!(options.max_levels, 5);
        assert_eq!(options.theta, 0.5);
        assert!(matches!(options.smoother, SmootherType::Jacobi));
        assert!(matches!(options.cycle_type, CycleType::W));
    }

    #[test]
    fn test_gauss_seidel_smoother() {
        let matrix = small_coupled_matrix(2.0, -1.0);
        let amg = amg_with(&matrix, AMGOptions::default());

        let rhs = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let mut iterate = Array1::from_vec(vec![0.0, 0.0, 0.0]);

        amg.gauss_seidel_smooth(&mut iterate, &rhs.view(), &matrix)
            .expect("Operation failed");

        // A sweep from the zero vector must move at least one unknown.
        assert!(iterate.iter().any(|&val| val.abs() > 1e-10));
    }

    #[test]
    fn test_enhanced_amg_coarsening() {
        let matrix = square_csr(
            vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4],
            vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 3, 4, 0],
            vec![
                4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, 4.0, -1.0,
            ],
            5,
        );

        let amg = amg_with(
            &matrix,
            AMGOptions {
                theta: 0.25,
                ..Default::default()
            },
        );

        assert!(amg.num_levels() >= 1);

        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let solution = amg.apply(&rhs.view()).expect("Operation failed");

        assert_eq!(solution.len(), 5);
        assert!(solution.iter().any(|&val| val.abs() > 1e-10));
    }

    #[test]
    fn test_strong_connection_detection() {
        let matrix = small_coupled_matrix(4.0, -2.0);
        let amg = amg_with(
            &matrix,
            AMGOptions {
                theta: 0.25,
                ..Default::default()
            },
        );

        let strong = amg
            .detect_strong_connections(&matrix)
            .expect("Operation failed");

        assert!(!strong[0].is_empty());
        assert!(!strong[1].is_empty());

        // Symmetric coupling between points 0 and 1 must be detected in both directions.
        if strong[0].contains(&1) {
            assert!(strong[1].contains(&0));
        }
    }
}