scirs2_sparse/linalg/
amg.rs

//! Algebraic Multigrid (AMG) preconditioner for sparse linear systems
//!
//! AMG is a powerful preconditioner for solving large sparse linear systems,
//! particularly effective for systems arising from discretizations of
//! elliptic PDEs and similar problems. Its coarse levels are built from the
//! matrix entries alone, so no mesh or geometric information is required.
6
7use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::numeric::Float;
12use scirs2_core::SparseElement;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16/// Options for the AMG preconditioner
17#[derive(Debug, Clone)]
18pub struct AMGOptions {
19    /// Maximum number of levels in the multigrid hierarchy
20    pub max_levels: usize,
21    /// Strong connection threshold for coarsening (typically 0.25-0.5)
22    pub theta: f64,
23    /// Maximum size of coarse grid before switching to direct solver
24    pub max_coarse_size: usize,
25    /// Interpolation method
26    pub interpolation: InterpolationType,
27    /// Smoother type
28    pub smoother: SmootherType,
29    /// Number of pre-smoothing steps
30    pub pre_smooth_steps: usize,
31    /// Number of post-smoothing steps
32    pub post_smooth_steps: usize,
33    /// Cycle type (V-cycle, W-cycle, etc.)
34    pub cycle_type: CycleType,
35}
36
37impl Default for AMGOptions {
38    fn default() -> Self {
39        Self {
40            max_levels: 10,
41            theta: 0.25,
42            max_coarse_size: 50,
43            interpolation: InterpolationType::Classical,
44            smoother: SmootherType::GaussSeidel,
45            pre_smooth_steps: 1,
46            post_smooth_steps: 1,
47            cycle_type: CycleType::V,
48        }
49    }
50}
51
/// Interpolation methods for AMG
///
/// Derives `PartialEq`/`Eq`/`Hash` so options can be compared and used as
/// map keys without resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Classical Ruge-Stuben interpolation
    Classical,
    /// Direct interpolation
    Direct,
    /// Standard interpolation
    Standard,
}
62
/// Smoother types for AMG
///
/// Derives `PartialEq`/`Eq`/`Hash` so options can be compared and used as
/// map keys without resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmootherType {
    /// Gauss-Seidel smoother (forward sweep)
    GaussSeidel,
    /// Jacobi smoother
    Jacobi,
    /// SOR smoother (the relaxation factor is fixed by the preconditioner)
    SOR,
}
73
/// Cycle types for AMG
///
/// Derives `PartialEq`/`Eq`/`Hash` so options can be compared and used as
/// map keys without resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CycleType {
    /// V-cycle: one recursive visit of each coarser level
    V,
    /// W-cycle: two recursive visits of each coarser level
    W,
    /// F-cycle
    F,
}
84
85/// AMG preconditioner implementation
86#[derive(Debug)]
87pub struct AMGPreconditioner<T>
88where
89    T: Float + SparseElement + Debug + Copy + 'static,
90{
91    /// Matrices at each level
92    operators: Vec<CsrArray<T>>,
93    /// Prolongation operators (coarse to fine)
94    prolongations: Vec<CsrArray<T>>,
95    /// Restriction operators (fine to coarse)
96    restrictions: Vec<CsrArray<T>>,
97    /// AMG options
98    options: AMGOptions,
99    /// Number of levels in the hierarchy
100    num_levels: usize,
101}
102
103impl<T> AMGPreconditioner<T>
104where
105    T: Float + SparseElement + Debug + Copy + 'static,
106{
107    /// Create a new AMG preconditioner from a sparse matrix
108    ///
109    /// # Arguments
110    ///
111    /// * `matrix` - The coefficient matrix
112    /// * `options` - AMG options
113    ///
114    /// # Returns
115    ///
116    /// A new AMG preconditioner
117    ///
118    /// # Example
119    ///
120    /// ```rust
121    /// use scirs2_sparse::csr_array::CsrArray;
122    /// use scirs2_sparse::linalg::{AMGPreconditioner, AMGOptions};
123    ///
124    /// // Create a simple matrix
125    /// let rows = vec![0, 0, 1, 1, 2, 2];
126    /// let cols = vec![0, 1, 0, 1, 1, 2];
127    /// let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
128    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
129    ///
130    /// // Create AMG preconditioner
131    /// let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();
132    /// ```
133    pub fn new(matrix: &CsrArray<T>, options: AMGOptions) -> SparseResult<Self> {
134        let mut amg = AMGPreconditioner {
135            operators: vec![matrix.clone()],
136            prolongations: Vec::new(),
137            restrictions: Vec::new(),
138            options,
139            num_levels: 1,
140        };
141
142        // Build the multigrid hierarchy
143        amg.build_hierarchy()?;
144
145        Ok(amg)
146    }
147
148    /// Build the multigrid hierarchy
149    fn build_hierarchy(&mut self) -> SparseResult<()> {
150        let mut level = 0;
151
152        while level < self.options.max_levels - 1 {
153            let currentmatrix = &self.operators[level];
154            let (rows, _) = currentmatrix.shape();
155
156            // Stop if matrix is small enough
157            if rows <= self.options.max_coarse_size {
158                break;
159            }
160
161            // Coarsen the matrix
162            let (coarsematrix, prolongation, restriction) = self.coarsen_level(currentmatrix)?;
163
164            // Check if coarsening was successful
165            let (coarse_rows, _) = coarsematrix.shape();
166            if coarse_rows >= rows {
167                // Coarsening didn't reduce the problem size significantly
168                break;
169            }
170
171            self.operators.push(coarsematrix);
172            self.prolongations.push(prolongation);
173            self.restrictions.push(restriction);
174            self.num_levels += 1;
175            level += 1;
176        }
177
178        Ok(())
179    }
180
181    /// Coarsen a single level using Ruge-Stuben algebraic coarsening
182    fn coarsen_level(
183        &self,
184        matrix: &CsrArray<T>,
185    ) -> SparseResult<(CsrArray<T>, CsrArray<T>, CsrArray<T>)> {
186        let (n, _) = matrix.shape();
187
188        // Step 1: Detect strong connections
189        let strong_connections = self.detect_strong_connections(matrix)?;
190
191        // Step 2: Perform C/F splitting using classical Ruge-Stuben algorithm
192        let (c_points, f_points) = self.classical_cf_splitting(matrix, &strong_connections)?;
193
194        // Step 3: Build coarse point mapping
195        let mut fine_to_coarse = HashMap::new();
196        for (coarse_idx, &fine_idx) in c_points.iter().enumerate() {
197            fine_to_coarse.insert(fine_idx, coarse_idx);
198        }
199
200        let coarse_size = c_points.len();
201
202        // Build prolongation operator (interpolation)
203        let prolongation = self.build_prolongation(matrix, &fine_to_coarse, coarse_size)?;
204
205        // Build restriction operator (typically transpose of prolongation)
206        let restriction_box = prolongation.transpose()?;
207        let restriction = restriction_box
208            .as_any()
209            .downcast_ref::<CsrArray<T>>()
210            .ok_or_else(|| {
211                SparseError::ValueError("Failed to downcast restriction to CsrArray".to_string())
212            })?
213            .clone();
214
215        // Build coarse matrix: A_coarse = R * A * P
216        let temp_box = restriction.dot(matrix)?;
217        let temp = temp_box
218            .as_any()
219            .downcast_ref::<CsrArray<T>>()
220            .ok_or_else(|| {
221                SparseError::ValueError("Failed to downcast temp to CsrArray".to_string())
222            })?;
223        let coarsematrix_box = temp.dot(&prolongation)?;
224        let coarsematrix = coarsematrix_box
225            .as_any()
226            .downcast_ref::<CsrArray<T>>()
227            .ok_or_else(|| {
228                SparseError::ValueError("Failed to downcast coarsematrix to CsrArray".to_string())
229            })?
230            .clone();
231
232        Ok((coarsematrix, prolongation, restriction))
233    }
234
235    /// Detect strong connections in the matrix
236    /// A connection i -> j is strong if |a_ij| >= theta * max_k(|a_ik|) for k != i
237    fn detect_strong_connections(&self, matrix: &CsrArray<T>) -> SparseResult<Vec<Vec<usize>>> {
238        let (n, _) = matrix.shape();
239        let mut strong_connections = vec![Vec::new(); n];
240
241        #[allow(clippy::needless_range_loop)]
242        for i in 0..n {
243            let row_start = matrix.get_indptr()[i];
244            let row_end = matrix.get_indptr()[i + 1];
245
246            // Find maximum off-diagonal magnitude in this row
247            let mut max_off_diag = T::sparse_zero();
248            for j in row_start..row_end {
249                let col = matrix.get_indices()[j];
250                if col != i {
251                    let val = matrix.get_data()[j].abs();
252                    if val > max_off_diag {
253                        max_off_diag = val;
254                    }
255                }
256            }
257
258            // Identify strong connections
259            let threshold = T::from(self.options.theta).unwrap() * max_off_diag;
260            for j in row_start..row_end {
261                let col = matrix.get_indices()[j];
262                if col != i {
263                    let val = matrix.get_data()[j].abs();
264                    if val >= threshold {
265                        strong_connections[i].push(col);
266                    }
267                }
268            }
269        }
270
271        Ok(strong_connections)
272    }
273
274    /// Classical Ruge-Stuben C/F splitting algorithm
275    fn classical_cf_splitting(
276        &self,
277        matrix: &CsrArray<T>,
278        strong_connections: &[Vec<usize>],
279    ) -> SparseResult<(Vec<usize>, Vec<usize>)> {
280        let (n, _) = matrix.shape();
281
282        // Count strong _connections for each point (influence measure)
283        let mut influence = vec![0; n];
284        for i in 0..n {
285            influence[i] = strong_connections[i].len();
286        }
287
288        // Track point types: 0 = undecided, 1 = C-point, 2 = F-point
289        let mut point_type = vec![0; n];
290        let mut c_points = Vec::new();
291        let mut f_points = Vec::new();
292
293        // Sort points by influence (high influence points become C-points first)
294        let mut sorted_points: Vec<usize> = (0..n).collect();
295        sorted_points.sort_by(|&a, &b| influence[b].cmp(&influence[a]));
296
297        for &i in &sorted_points {
298            if point_type[i] != 0 {
299                continue; // Already processed
300            }
301
302            // Check if this point needs to be a C-point
303            let mut needs_coarse = false;
304
305            // If this point has strong F-point neighbors without coarse interpolatory set
306            for &j in &strong_connections[i] {
307                if point_type[j] == 2 {
308                    // F-point
309                    // Check if F-point j has a coarse interpolatory set
310                    let mut has_coarse_interp = false;
311                    for &k in &strong_connections[j] {
312                        if point_type[k] == 1 {
313                            // C-point
314                            has_coarse_interp = true;
315                            break;
316                        }
317                    }
318                    if !has_coarse_interp {
319                        needs_coarse = true;
320                        break;
321                    }
322                }
323            }
324
325            if needs_coarse || influence[i] > 2 {
326                // Make this a C-point
327                point_type[i] = 1;
328                c_points.push(i);
329
330                // Make strongly connected neighbors F-points
331                for &j in &strong_connections[i] {
332                    if point_type[j] == 0 {
333                        point_type[j] = 2;
334                        f_points.push(j);
335                    }
336                }
337            }
338        }
339
340        // Assign remaining undecided points as F-points
341        #[allow(clippy::needless_range_loop)]
342        for i in 0..n {
343            if point_type[i] == 0 {
344                point_type[i] = 2;
345                f_points.push(i);
346            }
347        }
348
349        Ok((c_points, f_points))
350    }
351
352    /// Build prolongation (interpolation) operator using algebraic interpolation
353    fn build_prolongation(
354        &self,
355        matrix: &CsrArray<T>,
356        fine_to_coarse: &HashMap<usize, usize>,
357        coarse_size: usize,
358    ) -> SparseResult<CsrArray<T>> {
359        let (n, _) = matrix.shape();
360        let mut prolongation_data = Vec::new();
361        let mut prolongation_indices = Vec::new();
362        let mut prolongation_indptr = vec![0];
363
364        // Detect strong connections for interpolation
365        let strong_connections = self.detect_strong_connections(matrix)?;
366
367        #[allow(clippy::needless_range_loop)]
368        for i in 0..n {
369            if let Some(&coarse_idx) = fine_to_coarse.get(&i) {
370                // Direct injection for _coarse points
371                prolongation_data.push(T::sparse_one());
372                prolongation_indices.push(coarse_idx);
373            } else {
374                // Algebraic interpolation for fine points
375                let interp_weights = self.compute_interpolation_weights(
376                    i,
377                    matrix,
378                    &strong_connections[i],
379                    fine_to_coarse,
380                )?;
381
382                if interp_weights.is_empty() {
383                    // Fallback: direct injection to first _coarse point
384                    prolongation_data.push(T::sparse_one());
385                    prolongation_indices.push(0);
386                } else {
387                    // Add interpolation weights
388                    for (coarse_idx, weight) in interp_weights {
389                        prolongation_data.push(weight);
390                        prolongation_indices.push(coarse_idx);
391                    }
392                }
393            }
394            prolongation_indptr.push(prolongation_data.len());
395        }
396
397        CsrArray::new(
398            prolongation_data.into(),
399            prolongation_indptr.into(),
400            prolongation_indices.into(),
401            (n, coarse_size),
402        )
403    }
404
405    /// Compute interpolation weights for a fine point using classical interpolation
406    fn compute_interpolation_weights(
407        &self,
408        fine_point: usize,
409        matrix: &CsrArray<T>,
410        strong_neighbors: &[usize],
411        fine_to_coarse: &HashMap<usize, usize>,
412    ) -> SparseResult<Vec<(usize, T)>> {
413        let mut weights = Vec::new();
414
415        // Find _coarse _neighbors that are strongly connected
416        let mut coarse_neighbors = Vec::new();
417        let mut coarse_weights = Vec::new();
418
419        for &neighbor in strong_neighbors {
420            if let Some(&coarse_idx) = fine_to_coarse.get(&neighbor) {
421                coarse_neighbors.push(neighbor);
422                coarse_weights.push(coarse_idx);
423            }
424        }
425
426        if coarse_neighbors.is_empty() {
427            return Ok(weights);
428        }
429
430        // Get the diagonal entry for fine _point
431        let mut a_ii = T::sparse_zero();
432        let row_start = matrix.get_indptr()[fine_point];
433        let row_end = matrix.get_indptr()[fine_point + 1];
434
435        for j in row_start..row_end {
436            let col = matrix.get_indices()[j];
437            if col == fine_point {
438                a_ii = matrix.get_data()[j];
439                break;
440            }
441        }
442
443        if SparseElement::is_zero(&a_ii) {
444            return Ok(weights);
445        }
446
447        // Compute interpolation weights using classical formula
448        // w_j = -a_ij / a_ii for _coarse _neighbors j
449        let mut total_weight = T::sparse_zero();
450        let mut temp_weights = Vec::new();
451
452        for &coarse_neighbor in &coarse_neighbors {
453            let mut a_ij = T::sparse_zero();
454            for j in row_start..row_end {
455                let col = matrix.get_indices()[j];
456                if col == coarse_neighbor {
457                    a_ij = matrix.get_data()[j];
458                    break;
459                }
460            }
461
462            if !SparseElement::is_zero(&a_ij) {
463                let weight = -a_ij / a_ii;
464                temp_weights.push(weight);
465                total_weight = total_weight + weight;
466            } else {
467                temp_weights.push(T::sparse_zero());
468            }
469        }
470
471        // Normalize weights to sum to 1
472        if !SparseElement::is_zero(&total_weight) {
473            for (i, &coarse_idx) in coarse_weights.iter().enumerate() {
474                let normalized_weight = temp_weights[i] / total_weight;
475                if !SparseElement::is_zero(&normalized_weight) {
476                    weights.push((coarse_idx, normalized_weight));
477                }
478            }
479        }
480
481        Ok(weights)
482    }
483
484    /// Apply the AMG preconditioner
485    ///
486    /// Solves the system M * x = b approximately, where M is the preconditioner
487    ///
488    /// # Arguments
489    ///
490    /// * `b` - Right-hand side vector
491    ///
492    /// # Returns
493    ///
494    /// Approximate solution x
495    pub fn apply(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
496        let (n, _) = self.operators[0].shape();
497        if b.len() != n {
498            return Err(SparseError::DimensionMismatch {
499                expected: n,
500                found: b.len(),
501            });
502        }
503
504        let mut x = Array1::zeros(n);
505        self.mg_cycle(&mut x, b, 0)?;
506        Ok(x)
507    }
508
509    /// Perform one multigrid cycle
510    fn mg_cycle(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
511        if level == self.num_levels - 1 {
512            // Coarsest level - solve directly (simplified)
513            self.coarse_solve(x, b, level)?;
514            return Ok(());
515        }
516
517        let matrix = &self.operators[level];
518
519        // Pre-smoothing
520        for _ in 0..self.options.pre_smooth_steps {
521            self.smooth(x, b, matrix)?;
522        }
523
524        // Compute residual
525        let ax = matrix_vector_multiply(matrix, &x.view())?;
526        let residual = b - &ax;
527
528        // Restrict residual to coarse grid
529        let restriction = &self.restrictions[level];
530        let coarse_residual = matrix_vector_multiply(restriction, &residual.view())?;
531
532        // Solve on coarse grid
533        let coarse_size = coarse_residual.len();
534        let mut coarse_correction = Array1::zeros(coarse_size);
535
536        match self.options.cycle_type {
537            CycleType::V => {
538                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
539            }
540            CycleType::W => {
541                // Two recursive calls for W-cycle
542                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
543                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
544            }
545            CycleType::F => {
546                // Full multigrid - not implemented here
547                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
548            }
549        }
550
551        // Prolongate correction to fine grid
552        let prolongation = &self.prolongations[level];
553        let fine_correction = matrix_vector_multiply(prolongation, &coarse_correction.view())?;
554
555        // Add correction
556        for i in 0..x.len() {
557            x[i] = x[i] + fine_correction[i];
558        }
559
560        // Post-smoothing
561        for _ in 0..self.options.post_smooth_steps {
562            self.smooth(x, b, matrix)?;
563        }
564
565        Ok(())
566    }
567
568    /// Apply smoother
569    fn smooth(
570        &self,
571        x: &mut Array1<T>,
572        b: &ArrayView1<T>,
573        matrix: &CsrArray<T>,
574    ) -> SparseResult<()> {
575        match self.options.smoother {
576            SmootherType::GaussSeidel => self.gauss_seidel_smooth(x, b, matrix),
577            SmootherType::Jacobi => self.jacobi_smooth(x, b, matrix),
578            SmootherType::SOR => self.sor_smooth(x, b, matrix, T::from(1.2).unwrap()),
579        }
580    }
581
582    /// Gauss-Seidel smoother
583    fn gauss_seidel_smooth(
584        &self,
585        x: &mut Array1<T>,
586        b: &ArrayView1<T>,
587        matrix: &CsrArray<T>,
588    ) -> SparseResult<()> {
589        let n = x.len();
590
591        for i in 0..n {
592            let row_start = matrix.get_indptr()[i];
593            let row_end = matrix.get_indptr()[i + 1];
594
595            let mut sum = T::sparse_zero();
596            let mut diag_val = T::sparse_zero();
597
598            for j in row_start..row_end {
599                let col = matrix.get_indices()[j];
600                let val = matrix.get_data()[j];
601
602                if col == i {
603                    diag_val = val;
604                } else {
605                    sum = sum + val * x[col];
606                }
607            }
608
609            if !SparseElement::is_zero(&diag_val) {
610                x[i] = (b[i] - sum) / diag_val;
611            }
612        }
613
614        Ok(())
615    }
616
617    /// Jacobi smoother
618    fn jacobi_smooth(
619        &self,
620        x: &mut Array1<T>,
621        b: &ArrayView1<T>,
622        matrix: &CsrArray<T>,
623    ) -> SparseResult<()> {
624        let n = x.len();
625        let mut x_new = x.clone();
626
627        for i in 0..n {
628            let row_start = matrix.get_indptr()[i];
629            let row_end = matrix.get_indptr()[i + 1];
630
631            let mut sum = T::sparse_zero();
632            let mut diag_val = T::sparse_zero();
633
634            for j in row_start..row_end {
635                let col = matrix.get_indices()[j];
636                let val = matrix.get_data()[j];
637
638                if col == i {
639                    diag_val = val;
640                } else {
641                    sum = sum + val * x[col];
642                }
643            }
644
645            if !SparseElement::is_zero(&diag_val) {
646                x_new[i] = (b[i] - sum) / diag_val;
647            }
648        }
649
650        *x = x_new;
651        Ok(())
652    }
653
654    /// SOR smoother
655    fn sor_smooth(
656        &self,
657        x: &mut Array1<T>,
658        b: &ArrayView1<T>,
659        matrix: &CsrArray<T>,
660        omega: T,
661    ) -> SparseResult<()> {
662        let n = x.len();
663
664        for i in 0..n {
665            let row_start = matrix.get_indptr()[i];
666            let row_end = matrix.get_indptr()[i + 1];
667
668            let mut sum = T::sparse_zero();
669            let mut diag_val = T::sparse_zero();
670
671            for j in row_start..row_end {
672                let col = matrix.get_indices()[j];
673                let val = matrix.get_data()[j];
674
675                if col == i {
676                    diag_val = val;
677                } else {
678                    sum = sum + val * x[col];
679                }
680            }
681
682            if !SparseElement::is_zero(&diag_val) {
683                let x_gs = (b[i] - sum) / diag_val;
684                x[i] = (T::sparse_one() - omega) * x[i] + omega * x_gs;
685            }
686        }
687
688        Ok(())
689    }
690
691    /// Coarse grid solver (simplified direct method)
692    fn coarse_solve(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
693        // For now, just use a few iterations of Gauss-Seidel
694        let matrix = &self.operators[level];
695
696        for _ in 0..10 {
697            self.gauss_seidel_smooth(x, b, matrix)?;
698        }
699
700        Ok(())
701    }
702
703    /// Get the number of levels in the hierarchy
704    pub fn num_levels(&self) -> usize {
705        self.num_levels
706    }
707
708    /// Get the size of the matrix at a given level
709    pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
710        if level < self.num_levels {
711            Some(self.operators[level].shape())
712        } else {
713            None
714        }
715    }
716}
717
718/// Helper function for matrix-vector multiplication
719#[allow(dead_code)]
720fn matrix_vector_multiply<T>(matrix: &CsrArray<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
721where
722    T: Float + SparseElement + Debug + Copy + 'static,
723{
724    let (rows, cols) = matrix.shape();
725    if x.len() != cols {
726        return Err(SparseError::DimensionMismatch {
727            expected: cols,
728            found: x.len(),
729        });
730    }
731
732    let mut result = Array1::zeros(rows);
733
734    for i in 0..rows {
735        for j in matrix.get_indptr()[i]..matrix.get_indptr()[i + 1] {
736            let col = matrix.get_indices()[j];
737            let val = matrix.get_data()[j];
738            result[i] = result[i] + val * x[col];
739        }
740    }
741
742    Ok(result)
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    // 3x3 test matrix with diagonal `d` and off-diagonal entries `o` at
    // (0,1), (1,0) and (2,1):
    //   [ d  o  0 ]
    //   [ o  d  0 ]
    //   [ 0  o  d ]
    fn three_by_three(d: f64, o: f64) -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![d, o, o, d, o, d];
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    #[test]
    fn test_amg_preconditioner_creation() {
        let matrix = three_by_three(2.0, -1.0);
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        assert!(amg.num_levels() >= 1);
        assert_eq!(amg.level_size(0), Some((3, 3)));
    }

    #[test]
    fn test_amg_apply() {
        // diag(2, 3, 4): the exact solution of A x = [2, 3, 4] is [1, 1, 1].
        let diagonal_positions = vec![0, 1, 2];
        let values = vec![2.0, 3.0, 4.0];
        let matrix = CsrArray::from_triplets(
            &diagonal_positions,
            &diagonal_positions,
            &values,
            (3, 3),
            false,
        )
        .unwrap();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let rhs = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let solution = amg.apply(&rhs.view()).unwrap();

        for &component in solution.iter() {
            assert!(component > 0.5 && component < 1.5);
        }
    }

    #[test]
    fn test_amg_options() {
        let options = AMGOptions {
            max_levels: 5,
            theta: 0.5,
            smoother: SmootherType::Jacobi,
            cycle_type: CycleType::W,
            ..Default::default()
        };

        assert_eq!(options.max_levels, 5);
        assert_eq!(options.theta, 0.5);
        assert!(matches!(options.smoother, SmootherType::Jacobi));
        assert!(matches!(options.cycle_type, CycleType::W));
    }

    #[test]
    fn test_gauss_seidel_smoother() {
        let matrix = three_by_three(2.0, -1.0);
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).unwrap();

        let mut iterate = Array1::from_vec(vec![0.0; 3]);
        let rhs = Array1::from_vec(vec![1.0; 3]);

        // A single sweep from zero must move at least one component.
        amg.gauss_seidel_smooth(&mut iterate, &rhs.view(), &matrix)
            .unwrap();
        assert!(iterate.iter().any(|&v| v.abs() > 1e-10));
    }

    #[test]
    fn test_enhanced_amg_coarsening() {
        // 5x5 system with 4 on the diagonal and -1 couplings.
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4];
        let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 3, 4, 0];
        let data = vec![
            4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, 4.0, -1.0,
        ];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (5, 5), false).unwrap();

        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).unwrap();
        assert!(amg.num_levels() >= 1);

        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let solution = amg.apply(&rhs.view()).unwrap();

        assert_eq!(solution.len(), 5);
        assert!(solution.iter().any(|&v| v.abs() > 1e-10));
    }

    #[test]
    fn test_strong_connection_detection() {
        let matrix = three_by_three(4.0, -2.0);
        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).unwrap();

        let strong = amg.detect_strong_connections(&matrix).unwrap();

        assert!(!strong[0].is_empty());
        assert!(!strong[1].is_empty());

        // The (0, 1) / (1, 0) coupling is symmetric, so strength must be too.
        if strong[0].contains(&1) {
            assert!(strong[1].contains(&0));
        }
    }
}