// scirs2_sparse/linalg/amg.rs

//! Algebraic Multigrid (AMG) preconditioner for sparse linear systems
//!
//! AMG is a powerful preconditioner for solving large sparse linear systems,
//! particularly effective for systems arising from discretizations of
//! elliptic PDEs. Unlike geometric multigrid, it builds its coarse levels
//! from the matrix entries alone, so no mesh information is required.
6
7use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::numeric::Float;
12use scirs2_core::SparseElement;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16/// Options for the AMG preconditioner
17#[derive(Debug, Clone)]
18pub struct AMGOptions {
19    /// Maximum number of levels in the multigrid hierarchy
20    pub max_levels: usize,
21    /// Strong connection threshold for coarsening (typically 0.25-0.5)
22    pub theta: f64,
23    /// Maximum size of coarse grid before switching to direct solver
24    pub max_coarse_size: usize,
25    /// Interpolation method
26    pub interpolation: InterpolationType,
27    /// Smoother type
28    pub smoother: SmootherType,
29    /// Number of pre-smoothing steps
30    pub pre_smooth_steps: usize,
31    /// Number of post-smoothing steps
32    pub post_smooth_steps: usize,
33    /// Cycle type (V-cycle, W-cycle, etc.)
34    pub cycle_type: CycleType,
35}
36
37impl Default for AMGOptions {
38    fn default() -> Self {
39        Self {
40            max_levels: 10,
41            theta: 0.25,
42            max_coarse_size: 50,
43            interpolation: InterpolationType::Classical,
44            smoother: SmootherType::GaussSeidel,
45            pre_smooth_steps: 1,
46            post_smooth_steps: 1,
47            cycle_type: CycleType::V,
48        }
49    }
50}
51
/// Interpolation methods for AMG
///
/// Selected through `AMGOptions::interpolation`. Derives `PartialEq`/`Eq`/`Hash`
/// so configurations can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Classical Ruge-Stuben interpolation
    Classical,
    /// Direct interpolation
    Direct,
    /// Standard interpolation
    Standard,
}
62
/// Smoother types for AMG
///
/// Selected through `AMGOptions::smoother`. Derives `PartialEq`/`Eq`/`Hash`
/// so configurations can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmootherType {
    /// Gauss-Seidel smoother
    GaussSeidel,
    /// Jacobi smoother
    Jacobi,
    /// SOR smoother (the variant carries no relaxation factor; the
    /// preconditioner chooses it)
    SOR,
}
73
/// Cycle types for AMG
///
/// Selected through `AMGOptions::cycle_type`. Derives `PartialEq`/`Eq`/`Hash`
/// so configurations can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CycleType {
    /// V-cycle
    V,
    /// W-cycle
    W,
    /// F-cycle
    F,
}
84
85/// AMG preconditioner implementation
86#[derive(Debug)]
87pub struct AMGPreconditioner<T>
88where
89    T: Float + SparseElement + Debug + Copy + 'static,
90{
91    /// Matrices at each level
92    operators: Vec<CsrArray<T>>,
93    /// Prolongation operators (coarse to fine)
94    prolongations: Vec<CsrArray<T>>,
95    /// Restriction operators (fine to coarse)
96    restrictions: Vec<CsrArray<T>>,
97    /// AMG options
98    options: AMGOptions,
99    /// Number of levels in the hierarchy
100    num_levels: usize,
101}
102
103impl<T> AMGPreconditioner<T>
104where
105    T: Float + SparseElement + Debug + Copy + 'static,
106{
107    /// Create a new AMG preconditioner from a sparse matrix
108    ///
109    /// # Arguments
110    ///
111    /// * `matrix` - The coefficient matrix
112    /// * `options` - AMG options
113    ///
114    /// # Returns
115    ///
116    /// A new AMG preconditioner
117    ///
118    /// # Example
119    ///
120    /// ```rust
121    /// use scirs2_sparse::csr_array::CsrArray;
122    /// use scirs2_sparse::linalg::{AMGPreconditioner, AMGOptions};
123    ///
124    /// // Create a simple matrix
125    /// let rows = vec![0, 0, 1, 1, 2, 2];
126    /// let cols = vec![0, 1, 0, 1, 1, 2];
127    /// let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
128    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
129    ///
130    /// // Create AMG preconditioner
131    /// let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");
132    /// ```
133    pub fn new(matrix: &CsrArray<T>, options: AMGOptions) -> SparseResult<Self> {
134        let mut amg = AMGPreconditioner {
135            operators: vec![matrix.clone()],
136            prolongations: Vec::new(),
137            restrictions: Vec::new(),
138            options,
139            num_levels: 1,
140        };
141
142        // Build the multigrid hierarchy
143        amg.build_hierarchy()?;
144
145        Ok(amg)
146    }
147
148    /// Build the multigrid hierarchy
149    fn build_hierarchy(&mut self) -> SparseResult<()> {
150        let mut level = 0;
151
152        while level < self.options.max_levels - 1 {
153            let currentmatrix = &self.operators[level];
154            let (rows, _) = currentmatrix.shape();
155
156            // Stop if matrix is small enough
157            if rows <= self.options.max_coarse_size {
158                break;
159            }
160
161            // Coarsen the matrix
162            let (coarsematrix, prolongation, restriction) = self.coarsen_level(currentmatrix)?;
163
164            // Check if coarsening was successful
165            let (coarse_rows, _) = coarsematrix.shape();
166            if coarse_rows >= rows {
167                // Coarsening didn't reduce the problem size significantly
168                break;
169            }
170
171            self.operators.push(coarsematrix);
172            self.prolongations.push(prolongation);
173            self.restrictions.push(restriction);
174            self.num_levels += 1;
175            level += 1;
176        }
177
178        Ok(())
179    }
180
181    /// Coarsen a single level using Ruge-Stuben algebraic coarsening
182    fn coarsen_level(
183        &self,
184        matrix: &CsrArray<T>,
185    ) -> SparseResult<(CsrArray<T>, CsrArray<T>, CsrArray<T>)> {
186        let (n, _) = matrix.shape();
187
188        // Step 1: Detect strong connections
189        let strong_connections = self.detect_strong_connections(matrix)?;
190
191        // Step 2: Perform C/F splitting using classical Ruge-Stuben algorithm
192        let (c_points, f_points) = self.classical_cf_splitting(matrix, &strong_connections)?;
193
194        // Step 3: Build coarse point mapping
195        let mut fine_to_coarse = HashMap::new();
196        for (coarse_idx, &fine_idx) in c_points.iter().enumerate() {
197            fine_to_coarse.insert(fine_idx, coarse_idx);
198        }
199
200        let coarse_size = c_points.len();
201
202        // Build prolongation operator (interpolation)
203        let prolongation = self.build_prolongation(matrix, &fine_to_coarse, coarse_size)?;
204
205        // Build restriction operator (typically transpose of prolongation)
206        let restriction_box = prolongation.transpose()?;
207        let restriction = restriction_box
208            .as_any()
209            .downcast_ref::<CsrArray<T>>()
210            .ok_or_else(|| {
211                SparseError::ValueError("Failed to downcast restriction to CsrArray".to_string())
212            })?
213            .clone();
214
215        // Build coarse matrix: A_coarse = R * A * P
216        let temp_box = restriction.dot(matrix)?;
217        let temp = temp_box
218            .as_any()
219            .downcast_ref::<CsrArray<T>>()
220            .ok_or_else(|| {
221                SparseError::ValueError("Failed to downcast temp to CsrArray".to_string())
222            })?;
223        let coarsematrix_box = temp.dot(&prolongation)?;
224        let coarsematrix = coarsematrix_box
225            .as_any()
226            .downcast_ref::<CsrArray<T>>()
227            .ok_or_else(|| {
228                SparseError::ValueError("Failed to downcast coarsematrix to CsrArray".to_string())
229            })?
230            .clone();
231
232        Ok((coarsematrix, prolongation, restriction))
233    }
234
235    /// Detect strong connections in the matrix
236    /// A connection i -> j is strong if |a_ij| >= theta * max_k(|a_ik|) for k != i
237    fn detect_strong_connections(&self, matrix: &CsrArray<T>) -> SparseResult<Vec<Vec<usize>>> {
238        let (n, _) = matrix.shape();
239        let mut strong_connections = vec![Vec::new(); n];
240
241        #[allow(clippy::needless_range_loop)]
242        for i in 0..n {
243            let row_start = matrix.get_indptr()[i];
244            let row_end = matrix.get_indptr()[i + 1];
245
246            // Find maximum off-diagonal magnitude in this row
247            let mut max_off_diag = T::sparse_zero();
248            for j in row_start..row_end {
249                let col = matrix.get_indices()[j];
250                if col != i {
251                    let val = matrix.get_data()[j].abs();
252                    if val > max_off_diag {
253                        max_off_diag = val;
254                    }
255                }
256            }
257
258            // Identify strong connections
259            let threshold = T::from(self.options.theta).expect("Operation failed") * max_off_diag;
260            for j in row_start..row_end {
261                let col = matrix.get_indices()[j];
262                if col != i {
263                    let val = matrix.get_data()[j].abs();
264                    if val >= threshold {
265                        strong_connections[i].push(col);
266                    }
267                }
268            }
269        }
270
271        Ok(strong_connections)
272    }
273
274    /// Classical Ruge-Stuben C/F splitting algorithm
275    fn classical_cf_splitting(
276        &self,
277        matrix: &CsrArray<T>,
278        strong_connections: &[Vec<usize>],
279    ) -> SparseResult<(Vec<usize>, Vec<usize>)> {
280        let (n, _) = matrix.shape();
281
282        // Count strong _connections for each point (influence measure)
283        let mut influence = vec![0; n];
284        for i in 0..n {
285            influence[i] = strong_connections[i].len();
286        }
287
288        // Track point types: 0 = undecided, 1 = C-point, 2 = F-point
289        let mut point_type = vec![0; n];
290        let mut c_points = Vec::new();
291        let mut f_points = Vec::new();
292
293        // Sort points by influence (high influence points become C-points first)
294        let mut sorted_points: Vec<usize> = (0..n).collect();
295        sorted_points.sort_by(|&a, &b| influence[b].cmp(&influence[a]));
296
297        for &i in &sorted_points {
298            if point_type[i] != 0 {
299                continue; // Already processed
300            }
301
302            // Check if this point needs to be a C-point
303            let mut needs_coarse = false;
304
305            // If this point has strong F-point neighbors without coarse interpolatory set
306            for &j in &strong_connections[i] {
307                if point_type[j] == 2 {
308                    // F-point
309                    // Check if F-point j has a coarse interpolatory set
310                    let mut has_coarse_interp = false;
311                    for &k in &strong_connections[j] {
312                        if point_type[k] == 1 {
313                            // C-point
314                            has_coarse_interp = true;
315                            break;
316                        }
317                    }
318                    if !has_coarse_interp {
319                        needs_coarse = true;
320                        break;
321                    }
322                }
323            }
324
325            if needs_coarse || influence[i] > 2 {
326                // Make this a C-point
327                point_type[i] = 1;
328                c_points.push(i);
329
330                // Make strongly connected neighbors F-points
331                for &j in &strong_connections[i] {
332                    if point_type[j] == 0 {
333                        point_type[j] = 2;
334                        f_points.push(j);
335                    }
336                }
337            }
338        }
339
340        // Assign remaining undecided points as F-points
341        #[allow(clippy::needless_range_loop)]
342        for i in 0..n {
343            if point_type[i] == 0 {
344                point_type[i] = 2;
345                f_points.push(i);
346            }
347        }
348
349        Ok((c_points, f_points))
350    }
351
352    /// Build prolongation (interpolation) operator using algebraic interpolation
353    fn build_prolongation(
354        &self,
355        matrix: &CsrArray<T>,
356        fine_to_coarse: &HashMap<usize, usize>,
357        coarse_size: usize,
358    ) -> SparseResult<CsrArray<T>> {
359        let (n, _) = matrix.shape();
360        let mut prolongation_data = Vec::new();
361        let mut prolongation_indices = Vec::new();
362        let mut prolongation_indptr = vec![0];
363
364        // Detect strong connections for interpolation
365        let strong_connections = self.detect_strong_connections(matrix)?;
366
367        #[allow(clippy::needless_range_loop)]
368        for i in 0..n {
369            if let Some(&coarse_idx) = fine_to_coarse.get(&i) {
370                // Direct injection for _coarse points
371                prolongation_data.push(T::sparse_one());
372                prolongation_indices.push(coarse_idx);
373            } else {
374                // Algebraic interpolation for fine points
375                let interp_weights = self.compute_interpolation_weights(
376                    i,
377                    matrix,
378                    &strong_connections[i],
379                    fine_to_coarse,
380                )?;
381
382                if interp_weights.is_empty() {
383                    // Fallback: direct injection to first _coarse point
384                    prolongation_data.push(T::sparse_one());
385                    prolongation_indices.push(0);
386                } else {
387                    // Add interpolation weights
388                    for (coarse_idx, weight) in interp_weights {
389                        prolongation_data.push(weight);
390                        prolongation_indices.push(coarse_idx);
391                    }
392                }
393            }
394            prolongation_indptr.push(prolongation_data.len());
395        }
396
397        CsrArray::new(
398            prolongation_data.into(),
399            prolongation_indptr.into(),
400            prolongation_indices.into(),
401            (n, coarse_size),
402        )
403    }
404
405    /// Compute interpolation weights for a fine point using classical interpolation
406    fn compute_interpolation_weights(
407        &self,
408        fine_point: usize,
409        matrix: &CsrArray<T>,
410        strong_neighbors: &[usize],
411        fine_to_coarse: &HashMap<usize, usize>,
412    ) -> SparseResult<Vec<(usize, T)>> {
413        let mut weights = Vec::new();
414
415        // Find _coarse _neighbors that are strongly connected
416        let mut coarse_neighbors = Vec::new();
417        let mut coarse_weights = Vec::new();
418
419        for &neighbor in strong_neighbors {
420            if let Some(&coarse_idx) = fine_to_coarse.get(&neighbor) {
421                coarse_neighbors.push(neighbor);
422                coarse_weights.push(coarse_idx);
423            }
424        }
425
426        if coarse_neighbors.is_empty() {
427            return Ok(weights);
428        }
429
430        // Get the diagonal entry for fine _point
431        let mut a_ii = T::sparse_zero();
432        let row_start = matrix.get_indptr()[fine_point];
433        let row_end = matrix.get_indptr()[fine_point + 1];
434
435        for j in row_start..row_end {
436            let col = matrix.get_indices()[j];
437            if col == fine_point {
438                a_ii = matrix.get_data()[j];
439                break;
440            }
441        }
442
443        if SparseElement::is_zero(&a_ii) {
444            return Ok(weights);
445        }
446
447        // Compute interpolation weights using classical formula
448        // w_j = -a_ij / a_ii for _coarse _neighbors j
449        let mut total_weight = T::sparse_zero();
450        let mut temp_weights = Vec::new();
451
452        for &coarse_neighbor in &coarse_neighbors {
453            let mut a_ij = T::sparse_zero();
454            for j in row_start..row_end {
455                let col = matrix.get_indices()[j];
456                if col == coarse_neighbor {
457                    a_ij = matrix.get_data()[j];
458                    break;
459                }
460            }
461
462            if !SparseElement::is_zero(&a_ij) {
463                let weight = -a_ij / a_ii;
464                temp_weights.push(weight);
465                total_weight = total_weight + weight;
466            } else {
467                temp_weights.push(T::sparse_zero());
468            }
469        }
470
471        // Normalize weights to sum to 1
472        if !SparseElement::is_zero(&total_weight) {
473            for (i, &coarse_idx) in coarse_weights.iter().enumerate() {
474                let normalized_weight = temp_weights[i] / total_weight;
475                if !SparseElement::is_zero(&normalized_weight) {
476                    weights.push((coarse_idx, normalized_weight));
477                }
478            }
479        }
480
481        Ok(weights)
482    }
483
484    /// Apply the AMG preconditioner
485    ///
486    /// Solves the system M * x = b approximately, where M is the preconditioner
487    ///
488    /// # Arguments
489    ///
490    /// * `b` - Right-hand side vector
491    ///
492    /// # Returns
493    ///
494    /// Approximate solution x
495    pub fn apply(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
496        let (n, _) = self.operators[0].shape();
497        if b.len() != n {
498            return Err(SparseError::DimensionMismatch {
499                expected: n,
500                found: b.len(),
501            });
502        }
503
504        let mut x = Array1::zeros(n);
505        self.mg_cycle(&mut x, b, 0)?;
506        Ok(x)
507    }
508
509    /// Perform one multigrid cycle
510    fn mg_cycle(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
511        if level == self.num_levels - 1 {
512            // Coarsest level - solve directly (simplified)
513            self.coarse_solve(x, b, level)?;
514            return Ok(());
515        }
516
517        let matrix = &self.operators[level];
518
519        // Pre-smoothing
520        for _ in 0..self.options.pre_smooth_steps {
521            self.smooth(x, b, matrix)?;
522        }
523
524        // Compute residual
525        let ax = matrix_vector_multiply(matrix, &x.view())?;
526        let residual = b - &ax;
527
528        // Restrict residual to coarse grid
529        let restriction = &self.restrictions[level];
530        let coarse_residual = matrix_vector_multiply(restriction, &residual.view())?;
531
532        // Solve on coarse grid
533        let coarse_size = coarse_residual.len();
534        let mut coarse_correction = Array1::zeros(coarse_size);
535
536        match self.options.cycle_type {
537            CycleType::V => {
538                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
539            }
540            CycleType::W => {
541                // Two recursive calls for W-cycle
542                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
543                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
544            }
545            CycleType::F => {
546                // Full multigrid - not implemented here
547                self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
548            }
549        }
550
551        // Prolongate correction to fine grid
552        let prolongation = &self.prolongations[level];
553        let fine_correction = matrix_vector_multiply(prolongation, &coarse_correction.view())?;
554
555        // Add correction
556        for i in 0..x.len() {
557            x[i] = x[i] + fine_correction[i];
558        }
559
560        // Post-smoothing
561        for _ in 0..self.options.post_smooth_steps {
562            self.smooth(x, b, matrix)?;
563        }
564
565        Ok(())
566    }
567
568    /// Apply smoother
569    fn smooth(
570        &self,
571        x: &mut Array1<T>,
572        b: &ArrayView1<T>,
573        matrix: &CsrArray<T>,
574    ) -> SparseResult<()> {
575        match self.options.smoother {
576            SmootherType::GaussSeidel => self.gauss_seidel_smooth(x, b, matrix),
577            SmootherType::Jacobi => self.jacobi_smooth(x, b, matrix),
578            SmootherType::SOR => {
579                self.sor_smooth(x, b, matrix, T::from(1.2).expect("Operation failed"))
580            }
581        }
582    }
583
584    /// Gauss-Seidel smoother
585    fn gauss_seidel_smooth(
586        &self,
587        x: &mut Array1<T>,
588        b: &ArrayView1<T>,
589        matrix: &CsrArray<T>,
590    ) -> SparseResult<()> {
591        let n = x.len();
592
593        for i in 0..n {
594            let row_start = matrix.get_indptr()[i];
595            let row_end = matrix.get_indptr()[i + 1];
596
597            let mut sum = T::sparse_zero();
598            let mut diag_val = T::sparse_zero();
599
600            for j in row_start..row_end {
601                let col = matrix.get_indices()[j];
602                let val = matrix.get_data()[j];
603
604                if col == i {
605                    diag_val = val;
606                } else {
607                    sum = sum + val * x[col];
608                }
609            }
610
611            if !SparseElement::is_zero(&diag_val) {
612                x[i] = (b[i] - sum) / diag_val;
613            }
614        }
615
616        Ok(())
617    }
618
619    /// Jacobi smoother
620    fn jacobi_smooth(
621        &self,
622        x: &mut Array1<T>,
623        b: &ArrayView1<T>,
624        matrix: &CsrArray<T>,
625    ) -> SparseResult<()> {
626        let n = x.len();
627        let mut x_new = x.clone();
628
629        for i in 0..n {
630            let row_start = matrix.get_indptr()[i];
631            let row_end = matrix.get_indptr()[i + 1];
632
633            let mut sum = T::sparse_zero();
634            let mut diag_val = T::sparse_zero();
635
636            for j in row_start..row_end {
637                let col = matrix.get_indices()[j];
638                let val = matrix.get_data()[j];
639
640                if col == i {
641                    diag_val = val;
642                } else {
643                    sum = sum + val * x[col];
644                }
645            }
646
647            if !SparseElement::is_zero(&diag_val) {
648                x_new[i] = (b[i] - sum) / diag_val;
649            }
650        }
651
652        *x = x_new;
653        Ok(())
654    }
655
656    /// SOR smoother
657    fn sor_smooth(
658        &self,
659        x: &mut Array1<T>,
660        b: &ArrayView1<T>,
661        matrix: &CsrArray<T>,
662        omega: T,
663    ) -> SparseResult<()> {
664        let n = x.len();
665
666        for i in 0..n {
667            let row_start = matrix.get_indptr()[i];
668            let row_end = matrix.get_indptr()[i + 1];
669
670            let mut sum = T::sparse_zero();
671            let mut diag_val = T::sparse_zero();
672
673            for j in row_start..row_end {
674                let col = matrix.get_indices()[j];
675                let val = matrix.get_data()[j];
676
677                if col == i {
678                    diag_val = val;
679                } else {
680                    sum = sum + val * x[col];
681                }
682            }
683
684            if !SparseElement::is_zero(&diag_val) {
685                let x_gs = (b[i] - sum) / diag_val;
686                x[i] = (T::sparse_one() - omega) * x[i] + omega * x_gs;
687            }
688        }
689
690        Ok(())
691    }
692
693    /// Coarse grid solver (simplified direct method)
694    fn coarse_solve(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
695        // For now, just use a few iterations of Gauss-Seidel
696        let matrix = &self.operators[level];
697
698        for _ in 0..10 {
699            self.gauss_seidel_smooth(x, b, matrix)?;
700        }
701
702        Ok(())
703    }
704
705    /// Get the number of levels in the hierarchy
706    pub fn num_levels(&self) -> usize {
707        self.num_levels
708    }
709
710    /// Get the size of the matrix at a given level
711    pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
712        if level < self.num_levels {
713            Some(self.operators[level].shape())
714        } else {
715            None
716        }
717    }
718}
719
720/// Helper function for matrix-vector multiplication
721#[allow(dead_code)]
722fn matrix_vector_multiply<T>(matrix: &CsrArray<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
723where
724    T: Float + SparseElement + Debug + Copy + 'static,
725{
726    let (rows, cols) = matrix.shape();
727    if x.len() != cols {
728        return Err(SparseError::DimensionMismatch {
729            expected: cols,
730            found: x.len(),
731        });
732    }
733
734    let mut result = Array1::zeros(rows);
735
736    for i in 0..rows {
737        for j in matrix.get_indptr()[i]..matrix.get_indptr()[i + 1] {
738            let col = matrix.get_indices()[j];
739            let val = matrix.get_data()[j];
740            result[i] = result[i] + val * x[col];
741        }
742    }
743
744    Ok(result)
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// The 3x3 matrix [[2, -1, 0], [-1, 2, 0], [0, -1, 2]] shared by several tests.
    fn small_test_matrix() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed")
    }

    #[test]
    fn test_amg_preconditioner_creation() {
        let matrix = small_test_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");

        assert!(amg.num_levels() >= 1);
        assert_eq!(amg.level_size(0), Some((3, 3)));
    }

    #[test]
    fn test_level_size_out_of_range() {
        let matrix = small_test_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");

        assert_eq!(amg.level_size(amg.num_levels()), None);
    }

    #[test]
    fn test_amg_apply() {
        // Create a diagonal system (easy test case)
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");

        let b = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let x = amg.apply(&b.view()).expect("Operation failed");

        // For a diagonal system, AMG should get close to the exact solution [1, 1, 1]
        assert!(x[0] > 0.5 && x[0] < 1.5);
        assert!(x[1] > 0.5 && x[1] < 1.5);
        assert!(x[2] > 0.5 && x[2] < 1.5);
    }

    #[test]
    fn test_apply_rejects_wrong_length() {
        let matrix = small_test_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");

        let b = Array1::from_vec(vec![1.0, 2.0]);
        let err = amg
            .apply(&b.view())
            .expect_err("a right-hand side of the wrong length must be rejected");
        assert!(matches!(
            err,
            SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            }
        ));
    }

    #[test]
    fn test_amg_options() {
        let options = AMGOptions {
            max_levels: 5,
            theta: 0.5,
            smoother: SmootherType::Jacobi,
            cycle_type: CycleType::W,
            ..Default::default()
        };

        assert_eq!(options.max_levels, 5);
        assert_eq!(options.theta, 0.5);
        assert!(matches!(options.smoother, SmootherType::Jacobi));
        assert!(matches!(options.cycle_type, CycleType::W));
    }

    #[test]
    fn test_gauss_seidel_smoother() {
        let matrix = small_test_matrix();
        let amg = AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");

        let mut x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        // Apply one Gauss-Seidel iteration
        amg.gauss_seidel_smooth(&mut x, &b.view(), &matrix)
            .expect("Operation failed");

        // Solution should improve (move away from zero)
        assert!(x.iter().any(|&val| val.abs() > 1e-10));
    }

    #[test]
    fn test_enhanced_amg_coarsening() {
        // Create a larger test matrix to better test algebraic coarsening
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4];
        let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 3, 4, 0];
        let data = vec![
            4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, 4.0, -1.0,
        ];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (5, 5), false).expect("Operation failed");

        let options = AMGOptions {
            theta: 0.25, // Strong connection threshold
            ..Default::default()
        };

        let amg = AMGPreconditioner::new(&matrix, options).expect("Operation failed");

        // Should have created a hierarchy
        assert!(amg.num_levels() >= 1);

        // Test that it can be applied
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let x = amg.apply(&b.view()).expect("Operation failed");

        // Check that the result has the right size
        assert_eq!(x.len(), 5);

        // Check that the solution is reasonable (not all zeros)
        assert!(x.iter().any(|&val| val.abs() > 1e-10));
    }

    #[test]
    fn test_strong_connection_detection() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![4.0, -2.0, -2.0, 4.0, -2.0, 4.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).expect("Operation failed");

        let strong_connections = amg
            .detect_strong_connections(&matrix)
            .expect("Operation failed");

        // Each point should have some strong connections
        assert!(!strong_connections[0].is_empty());
        assert!(!strong_connections[1].is_empty());

        // Verify strong connections are bidirectional for symmetric matrix
        if strong_connections[0].contains(&1) {
            assert!(strong_connections[1].contains(&0));
        }
    }
}